diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 883eb5a6a993..ee4f4d2f7f01 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -276,6 +276,7 @@ jobs: - ":trino-clickhouse" - ":trino-hive,:trino-orc" - ":trino-hive,:trino-parquet -P test-parquet" + - ":trino-hive -P test-failure-recovery" - ":trino-mongodb,:trino-kafka,:trino-elasticsearch" - ":trino-elasticsearch -P test-stats" - ":trino-redis" diff --git a/client/trino-client/src/main/java/io/trino/client/StageStats.java b/client/trino-client/src/main/java/io/trino/client/StageStats.java index e08dc927c2f5..d18b8d8846b9 100644 --- a/client/trino-client/src/main/java/io/trino/client/StageStats.java +++ b/client/trino-client/src/main/java/io/trino/client/StageStats.java @@ -40,6 +40,8 @@ public class StageStats private final long processedRows; private final long processedBytes; private final long physicalInputBytes; + private final int failedTasks; + private final boolean coordinatorOnly; private final List subStages; @JsonCreator @@ -57,6 +59,8 @@ public StageStats( @JsonProperty("processedRows") long processedRows, @JsonProperty("processedBytes") long processedBytes, @JsonProperty("physicalInputBytes") long physicalInputBytes, + @JsonProperty("failedTasks") int failedTasks, + @JsonProperty("coordinatorOnly") boolean coordinatorOnly, @JsonProperty("subStages") List subStages) { this.stageId = stageId; @@ -72,6 +76,8 @@ public StageStats( this.processedRows = processedRows; this.processedBytes = processedBytes; this.physicalInputBytes = physicalInputBytes; + this.failedTasks = failedTasks; + this.coordinatorOnly = coordinatorOnly; this.subStages = ImmutableList.copyOf(requireNonNull(subStages, "subStages is null")); } @@ -153,6 +159,18 @@ public long getPhysicalInputBytes() return physicalInputBytes; } + @JsonProperty + public int getFailedTasks() + { + return failedTasks; + } + + @JsonProperty + public boolean isCoordinatorOnly() + { + return coordinatorOnly; + } + 
@JsonProperty public List getSubStages() { @@ -175,6 +193,8 @@ public String toString() .add("processedRows", processedRows) .add("processedBytes", processedBytes) .add("physicalInputBytes", physicalInputBytes) + .add("failedTasks", failedTasks) + .add("coordinatorOnly", coordinatorOnly) .add("subStages", subStages) .toString(); } @@ -199,6 +219,8 @@ public static class Builder private long processedRows; private long processedBytes; private long physicalInputBytes; + private int failedTasks; + private boolean coordinatorOnly; private List subStages; private Builder() {} @@ -281,6 +303,18 @@ public Builder setPhysicalInputBytes(long physicalInputBytes) return this; } + public Builder setFailedTasks(int failedTasks) + { + this.failedTasks = failedTasks; + return this; + } + + public Builder setCoordinatorOnly(boolean coordinatorOnly) + { + this.coordinatorOnly = coordinatorOnly; + return this; + } + public Builder setSubStages(List subStages) { this.subStages = ImmutableList.copyOf(requireNonNull(subStages, "subStages is null")); @@ -303,6 +337,8 @@ public StageStats build() processedRows, processedBytes, physicalInputBytes, + failedTasks, + coordinatorOnly, subStages); } } diff --git a/core/trino-main/src/main/java/io/trino/FeaturesConfig.java b/core/trino-main/src/main/java/io/trino/FeaturesConfig.java index 81fc678d5f2c..00901d3a90fe 100644 --- a/core/trino-main/src/main/java/io/trino/FeaturesConfig.java +++ b/core/trino-main/src/main/java/io/trino/FeaturesConfig.java @@ -23,6 +23,7 @@ import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.airlift.units.MaxDataSize; +import io.trino.operator.RetryPolicy; import io.trino.sql.analyzer.RegexLibrary; import javax.validation.constraints.DecimalMax; @@ -40,6 +41,7 @@ import static io.trino.sql.analyzer.RegexLibrary.JONI; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.SECONDS; @DefunctConfig({ 
"analyzer.experimental-syntax-enabled", @@ -145,6 +147,11 @@ public class FeaturesConfig private boolean disableSetPropertiesSecurityCheckForCreateDdl; private boolean incrementalHashArrayLoadFactorEnabled = true; + private RetryPolicy retryPolicy = RetryPolicy.NONE; + private int retryAttempts = 4; + private Duration retryInitialDelay = new Duration(10, SECONDS); + private Duration retryMaxDelay = new Duration(1, MINUTES); + public enum JoinReorderingStrategy { NONE, @@ -1108,4 +1115,58 @@ public FeaturesConfig setIncrementalHashArrayLoadFactorEnabled(boolean increment this.incrementalHashArrayLoadFactorEnabled = incrementalHashArrayLoadFactorEnabled; return this; } + + @NotNull + public RetryPolicy getRetryPolicy() + { + return retryPolicy; + } + + @Config("retry-policy") + public FeaturesConfig setRetryPolicy(RetryPolicy retryPolicy) + { + this.retryPolicy = retryPolicy; + return this; + } + + @Min(0) + public int getRetryAttempts() + { + return retryAttempts; + } + + @Config("retry-attempts") + public FeaturesConfig setRetryAttempts(int retryAttempts) + { + this.retryAttempts = retryAttempts; + return this; + } + + @NotNull + public Duration getRetryInitialDelay() + { + return retryInitialDelay; + } + + @Config("retry-initial-delay") + @ConfigDescription("Initial delay before initiating a retry attempt. Delay increases exponentially for each subsequent attempt up to 'retry_max_delay'") + public FeaturesConfig setRetryInitialDelay(Duration retryInitialDelay) + { + this.retryInitialDelay = retryInitialDelay; + return this; + } + + @NotNull + public Duration getRetryMaxDelay() + { + return retryMaxDelay; + } + + @Config("retry-max-delay") + @ConfigDescription("Maximum delay before initiating a retry attempt. 
Delay increases exponentially for each subsequent attempt starting from 'retry_initial_delay'") + public FeaturesConfig setRetryMaxDelay(Duration retryMaxDelay) + { + this.retryMaxDelay = retryMaxDelay; + return this; + } } diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index 6ea7984bd5fe..fc2976d4b2fb 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -24,6 +24,7 @@ import io.trino.execution.scheduler.NodeSchedulerConfig; import io.trino.memory.MemoryManagerConfig; import io.trino.memory.NodeMemoryConfig; +import io.trino.operator.RetryPolicy; import io.trino.spi.TrinoException; import io.trino.spi.session.PropertyMetadata; @@ -37,6 +38,7 @@ import static io.trino.plugin.base.session.PropertyMetadataUtil.dataSizeProperty; import static io.trino.plugin.base.session.PropertyMetadataUtil.durationProperty; import static io.trino.spi.StandardErrorCode.INVALID_SESSION_PROPERTY; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.session.PropertyMetadata.booleanProperty; import static io.trino.spi.session.PropertyMetadata.doubleProperty; import static io.trino.spi.session.PropertyMetadata.enumProperty; @@ -143,6 +145,10 @@ public final class SystemSessionProperties public static final String LEGACY_CATALOG_ROLES = "legacy_catalog_roles"; public static final String INCREMENTAL_HASH_ARRAY_LOAD_FACTOR_ENABLED = "incremental_hash_array_load_factor_enabled"; public static final String MAX_PARTIAL_TOP_N_MEMORY = "max_partial_top_n_memory"; + public static final String RETRY_POLICY = "retry_policy"; + public static final String RETRY_ATTEMPTS = "retry_attempts"; + public static final String RETRY_INITIAL_DELAY = "retry_initial_delay"; + public static final String RETRY_MAX_DELAY = "retry_max_delay"; private final List> 
sessionProperties; @@ -665,6 +671,27 @@ public SystemSessionProperties( MAX_PARTIAL_TOP_N_MEMORY, "Max memory size for partial Top N aggregations. This can be turned off by setting it with '0'.", taskManagerConfig.getMaxPartialTopNMemory(), + false), + enumProperty( + RETRY_POLICY, + "Retry policy", + RetryPolicy.class, + featuresConfig.getRetryPolicy(), + false), + integerProperty( + RETRY_ATTEMPTS, + "Maximum number of retry attempts", + featuresConfig.getRetryAttempts(), + false), + durationProperty( + RETRY_INITIAL_DELAY, + "Initial delay before initiating a retry attempt. Delay increases exponentially for each subsequent attempt up to 'retry_max_delay'", + featuresConfig.getRetryInitialDelay(), + false), + durationProperty( + RETRY_MAX_DELAY, + "Maximum delay before initiating a retry attempt. Delay increases exponentially for each subsequent attempt starting from 'retry_initial_delay'", + featuresConfig.getRetryMaxDelay(), false)); } @@ -1183,4 +1210,33 @@ public static DataSize getMaxPartialTopNMemory(Session session) { return session.getSystemProperty(MAX_PARTIAL_TOP_N_MEMORY, DataSize.class); } + + public static RetryPolicy getRetryPolicy(Session session) + { + RetryPolicy retryPolicy = session.getSystemProperty(RETRY_POLICY, RetryPolicy.class); + if (retryPolicy != RetryPolicy.NONE) { + if (isEnableDynamicFiltering(session)) { + throw new TrinoException(NOT_SUPPORTED, "Dynamic filtering is not supported with automatic retries enabled"); + } + if (isDistributedSortEnabled(session)) { + throw new TrinoException(NOT_SUPPORTED, "Distributed sort is not supported with automatic retries enabled"); + } + } + return retryPolicy; + } + + public static int getRetryAttempts(Session session) + { + return session.getSystemProperty(RETRY_ATTEMPTS, Integer.class); + } + + public static Duration getRetryInitialDelay(Session session) + { + return session.getSystemProperty(RETRY_INITIAL_DELAY, Duration.class); + } + + public static Duration getRetryMaxDelay(Session 
session) + { + return session.getSystemProperty(RETRY_MAX_DELAY, Duration.class); + } } diff --git a/core/trino-main/src/main/java/io/trino/event/SplitMonitor.java b/core/trino-main/src/main/java/io/trino/event/SplitMonitor.java index ea64691a6fad..75de9de0b368 100644 --- a/core/trino-main/src/main/java/io/trino/event/SplitMonitor.java +++ b/core/trino-main/src/main/java/io/trino/event/SplitMonitor.java @@ -89,7 +89,7 @@ private void splitCompletedEvent(TaskId taskId, DriverStats driverStats, @Nullab new SplitCompletedEvent( taskId.getQueryId().toString(), taskId.getStageId().toString(), - Integer.toString(taskId.getId()), + taskId.toString(), splitCatalog, driverStats.getCreateTime().toDate().toInstant(), Optional.ofNullable(driverStats.getStartTime()).map(startTime -> startTime.toDate().toInstant()), diff --git a/core/trino-main/src/main/java/io/trino/execution/DataDefinitionExecution.java b/core/trino-main/src/main/java/io/trino/execution/DataDefinitionExecution.java index 07b996c64791..4532be10f996 100644 --- a/core/trino-main/src/main/java/io/trino/execution/DataDefinitionExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/DataDefinitionExecution.java @@ -183,6 +183,12 @@ public void addOutputInfoListener(Consumer listener) // DDL does not have an output } + @Override + public void outputTaskFailed(TaskId taskId, Throwable failure) + { + // DDL does not have an output + } + @Override public ListenableFuture getStateChange(QueryState currentState) { diff --git a/core/trino-main/src/main/java/io/trino/execution/FailureInjectionConfig.java b/core/trino-main/src/main/java/io/trino/execution/FailureInjectionConfig.java new file mode 100644 index 000000000000..db6379879b2a --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/FailureInjectionConfig.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; +import io.airlift.units.Duration; + +import javax.validation.constraints.NotNull; + +import static java.util.concurrent.TimeUnit.MINUTES; + +public class FailureInjectionConfig +{ + private Duration expirationPeriod = new Duration(10, MINUTES); + private Duration requestTimeout = new Duration(2, MINUTES); + + @NotNull + public Duration getExpirationPeriod() + { + return expirationPeriod; + } + + @Config("failure-injection.expiration-period") + @ConfigDescription("Period after which an injected failure is considered expired and will no longer be triggering a failure") + public FailureInjectionConfig setExpirationPeriod(Duration expirationPeriod) + { + this.expirationPeriod = expirationPeriod; + return this; + } + + @NotNull + public Duration getRequestTimeout() + { + return requestTimeout; + } + + @Config("failure-injection.request-timeout") + @ConfigDescription("Period after which requests blocked to emulate a timeout are released") + public FailureInjectionConfig setRequestTimeout(Duration requestTimeout) + { + this.requestTimeout = requestTimeout; + return this; + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/FailureInjector.java b/core/trino-main/src/main/java/io/trino/execution/FailureInjector.java new file mode 100644 index 000000000000..43b74ce2cbd2 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/FailureInjector.java @@ -0,0 +1,209 @@ +/* + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import io.airlift.units.Duration; +import io.trino.spi.ErrorCode; +import io.trino.spi.ErrorCodeSupplier; +import io.trino.spi.ErrorType; +import io.trino.spi.TrinoException; + +import javax.inject.Inject; + +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.execution.FailureInjector.InjectedFailureType.TASK_FAILURE; +import static io.trino.spi.ErrorType.EXTERNAL; +import static io.trino.spi.ErrorType.INSUFFICIENT_RESOURCES; +import static io.trino.spi.ErrorType.INTERNAL_ERROR; +import static io.trino.spi.ErrorType.USER_ERROR; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +public class FailureInjector +{ + public static final String FAILURE_INJECTION_MESSAGE = "This error is injected by the failure injection service"; + + private final Cache failures; + private final Duration requestTimeout; + + @Inject + public FailureInjector(FailureInjectionConfig config) + { + this( + requireNonNull(config, "config is null").getExpirationPeriod(), + config.getRequestTimeout()); + } + + public FailureInjector(Duration expirationPeriod, Duration requestTimeout) + { + failures = CacheBuilder.newBuilder() + 
.expireAfterWrite(expirationPeriod.toMillis(), MILLISECONDS) + .build(); + this.requestTimeout = requireNonNull(requestTimeout, "requestTimeout is null"); + } + + public void injectTaskFailure( + String traceToken, + int stageId, + int partitionId, + int attemptId, + InjectedFailureType injectionType, + Optional errorType) + { + failures.put(new Key(traceToken, stageId, partitionId, attemptId), new InjectedFailure(injectionType, errorType)); + } + + public Optional getInjectedFailure( + String traceToken, + int stageId, + int partitionId, + int attemptId) + { + if (failures.size() == 0) { + return Optional.empty(); + } + return Optional.ofNullable(failures.getIfPresent(new Key(traceToken, stageId, partitionId, attemptId))); + } + + public Duration getRequestTimeout() + { + return requestTimeout; + } + + private static class Key + { + private final String traceToken; + private final int stageId; + private final int partitionId; + private final int attemptId; + + private Key(String traceToken, int stageId, int partitionId, int attemptId) + { + this.traceToken = requireNonNull(traceToken, "traceToken is null"); + this.stageId = stageId; + this.partitionId = partitionId; + this.attemptId = attemptId; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Key key = (Key) o; + return stageId == key.stageId && partitionId == key.partitionId && attemptId == key.attemptId && Objects.equals(traceToken, key.traceToken); + } + + @Override + public int hashCode() + { + return Objects.hash(traceToken, stageId, partitionId, attemptId); + } + } + + public enum InjectedFailureType + { + TASK_MANAGEMENT_REQUEST_FAILURE, + TASK_MANAGEMENT_REQUEST_TIMEOUT, + TASK_GET_RESULTS_REQUEST_FAILURE, + TASK_GET_RESULTS_REQUEST_TIMEOUT, + TASK_FAILURE, + } + + public static class InjectedFailure + { + private final InjectedFailureType injectedFailureType; + private final Optional 
taskFailureErrorType; + + public InjectedFailure(InjectedFailureType injectedFailureType, Optional taskFailureErrorType) + { + this.injectedFailureType = requireNonNull(injectedFailureType, "injectedFailureType is null"); + this.taskFailureErrorType = requireNonNull(taskFailureErrorType, "taskFailureErrorType is null"); + if (injectedFailureType == TASK_FAILURE) { + checkArgument(taskFailureErrorType.isPresent(), "error type must be present when injection type is task failure"); + } + else { + checkArgument(taskFailureErrorType.isEmpty(), "error type must not be present when injection type is not task failure"); + } + } + + public InjectedFailureType getInjectedFailureType() + { + return injectedFailureType; + } + + public ErrorType getTaskFailureErrorType() + { + return taskFailureErrorType.orElseThrow(() -> new IllegalStateException("this method must only be called for failure type of TASK_FAILURE")); + } + + public Throwable getTaskFailureException() + { + ErrorType errorType = getTaskFailureErrorType(); + return new TrinoException(InjectedErrorCode.getErrorCode(errorType), FAILURE_INJECTION_MESSAGE); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("injectedFailureType", injectedFailureType) + .add("taskFailureErrorType", taskFailureErrorType) + .toString(); + } + } + + public enum InjectedErrorCode + implements ErrorCodeSupplier + { + INJECTED_USER_ERROR(1, USER_ERROR), + INJECTED_INTERNAL_ERROR(2, INTERNAL_ERROR), + INJECTED_INSUFFICIENT_RESOURCES_ERROR(3, INSUFFICIENT_RESOURCES), + INJECTED_EXTERNAL_ERROR(4, EXTERNAL), + /**/; + + private final ErrorCode errorCode; + + InjectedErrorCode(int code, ErrorType type) + { + errorCode = new ErrorCode(code + 0x30000, name(), type); + } + + @Override + public ErrorCode toErrorCode() + { + return errorCode; + } + + public static InjectedErrorCode getErrorCode(ErrorType errorType) + { + for (InjectedErrorCode code : values()) { + if (code.toErrorCode().getType() == errorType) { + 
return code; + } + } + throw new IllegalArgumentException("unexpected error type: " + errorType); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/QueryExecution.java index b48bb2d567a3..aa52e399bf1d 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryExecution.java @@ -14,7 +14,7 @@ package io.trino.execution; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; import io.airlift.units.Duration; @@ -30,7 +30,7 @@ import java.net.URI; import java.util.List; -import java.util.Set; +import java.util.Map; import java.util.function.Consumer; import static java.util.Objects.requireNonNull; @@ -46,6 +46,8 @@ public interface QueryExecution void addOutputInfoListener(Consumer listener); + void outputTaskFailed(TaskId taskId, Throwable failure); + Plan getQueryPlan(); BasicQueryInfo getBasicQueryInfo(); @@ -96,14 +98,14 @@ class QueryOutputInfo { private final List columnNames; private final List columnTypes; - private final Set bufferLocations; + private final Map bufferLocations; private final boolean noMoreBufferLocations; - public QueryOutputInfo(List columnNames, List columnTypes, Set bufferLocations, boolean noMoreBufferLocations) + public QueryOutputInfo(List columnNames, List columnTypes, Map bufferLocations, boolean noMoreBufferLocations) { this.columnNames = ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null")); this.columnTypes = ImmutableList.copyOf(requireNonNull(columnTypes, "columnTypes is null")); - this.bufferLocations = ImmutableSet.copyOf(requireNonNull(bufferLocations, "bufferLocations is null")); + this.bufferLocations = ImmutableMap.copyOf(requireNonNull(bufferLocations, 
"bufferLocations is null")); this.noMoreBufferLocations = noMoreBufferLocations; } @@ -117,7 +119,7 @@ public List getColumnTypes() return columnTypes; } - public Set getBufferLocations() + public Map getBufferLocations() { return bufferLocations; } diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryManager.java b/core/trino-main/src/main/java/io/trino/execution/QueryManager.java index d15c20ffbda6..27a97d035db5 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryManager.java @@ -36,6 +36,11 @@ public interface QueryManager void addOutputInfoListener(QueryId queryId, Consumer listener) throws NoSuchElementException; + /** + * Notify that one of the output tasks failed for a given query + */ + void outputTaskFailed(TaskId taskId, Throwable failure); + /** * Add a listener that fires each time the query state changes. * Listener is always notified asynchronously using a dedicated notification thread pool so, care should diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java index 81a4e8827858..90e09fad9038 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java @@ -58,8 +58,9 @@ import java.net.URI; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -644,12 +645,22 @@ public void addOutputInfoListener(Consumer listener) outputManager.addOutputInfoListener(listener); } + public void addOutputTaskFailureListener(TaskFailureListener listener) + { + outputManager.addOutputTaskFailureListener(listener); + } + + public void outputTaskFailed(TaskId taskId, Throwable failure) + { + 
outputManager.outputTaskFailed(taskId, failure); + } + public void setColumns(List columnNames, List columnTypes) { outputManager.setColumns(columnNames, columnTypes); } - public void updateOutputLocations(Set newExchangeLocations, boolean noMoreExchangeLocations) + public void updateOutputLocations(Map newExchangeLocations, boolean noMoreExchangeLocations) { outputManager.updateOutputLocations(newExchangeLocations, noMoreExchangeLocations); } @@ -1037,7 +1048,7 @@ private static boolean isScheduled(Optional rootStage) } return getAllStages(rootStage).stream() .map(StageInfo::getState) - .allMatch(state -> state == StageState.RUNNING || state == StageState.FLUSHING || state.isDone()); + .allMatch(state -> state == StageState.RUNNING || state == StageState.PENDING || state.isDone()); } public Optional getFailureInfo() @@ -1074,6 +1085,7 @@ public void pruneQueryInfo() outputStage.getStageId(), outputStage.getState(), null, // Remove the plan + outputStage.isCoordinatorOnly(), outputStage.getTypes(), outputStage.getStageStats(), ImmutableList.of(), // Remove the tasks @@ -1187,10 +1199,15 @@ public static class QueryOutputManager @GuardedBy("this") private List columnTypes; @GuardedBy("this") - private final Set exchangeLocations = new LinkedHashSet<>(); + private final Map exchangeLocations = new LinkedHashMap<>(); @GuardedBy("this") private boolean noMoreExchangeLocations; + @GuardedBy("this") + private final Map outputTaskFailures = new HashMap<>(); + @GuardedBy("this") + private final List outputTaskFailureListeners = new ArrayList<>(); + public QueryOutputManager(Executor executor) { this.executor = requireNonNull(executor, "executor is null"); @@ -1227,7 +1244,7 @@ public void setColumns(List columnNames, List columnTypes) queryOutputInfo.ifPresent(info -> fireStateChanged(info, outputInfoListeners)); } - public void updateOutputLocations(Set newExchangeLocations, boolean noMoreExchangeLocations) + public void updateOutputLocations(Map newExchangeLocations, 
boolean noMoreExchangeLocations) { requireNonNull(newExchangeLocations, "newExchangeLocations is null"); @@ -1235,11 +1252,11 @@ public void updateOutputLocations(Set newExchangeLocations, boolean noMoreE List> outputInfoListeners; synchronized (this) { if (this.noMoreExchangeLocations) { - checkArgument(this.exchangeLocations.containsAll(newExchangeLocations), "New locations added after no more locations set"); + checkArgument(this.exchangeLocations.entrySet().containsAll(newExchangeLocations.entrySet()), "New locations added after no more locations set"); return; } - this.exchangeLocations.addAll(newExchangeLocations); + this.exchangeLocations.putAll(newExchangeLocations); this.noMoreExchangeLocations = noMoreExchangeLocations; queryOutputInfo = getQueryOutputInfo(); outputInfoListeners = ImmutableList.copyOf(this.outputInfoListeners); @@ -1247,6 +1264,32 @@ public void updateOutputLocations(Set newExchangeLocations, boolean noMoreE queryOutputInfo.ifPresent(info -> fireStateChanged(info, outputInfoListeners)); } + public void addOutputTaskFailureListener(TaskFailureListener listener) + { + Map failures; + synchronized (this) { + outputTaskFailureListeners.add(listener); + failures = ImmutableMap.copyOf(outputTaskFailures); + } + executor.execute(() -> { + failures.forEach(listener::onTaskFailed); + }); + } + + public void outputTaskFailed(TaskId taskId, Throwable failure) + { + List listeners; + synchronized (this) { + outputTaskFailures.putIfAbsent(taskId, failure); + listeners = ImmutableList.copyOf(outputTaskFailureListeners); + } + executor.execute(() -> { + for (TaskFailureListener listener : listeners) { + listener.onTaskFailed(taskId, failure); + } + }); + } + private synchronized Optional getQueryOutputInfo() { if (columnNames == null || columnTypes == null) { diff --git a/core/trino-main/src/main/java/io/trino/execution/RemoteTask.java b/core/trino-main/src/main/java/io/trino/execution/RemoteTask.java index 3320faf99c44..f19d81a2585d 100644 --- 
a/core/trino-main/src/main/java/io/trino/execution/RemoteTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/RemoteTask.java @@ -63,6 +63,8 @@ public interface RemoteTask PartitionedSplitsInfo getPartitionedSplitsInfo(); + void fail(Throwable cause); + PartitionedSplitsInfo getQueuedPartitionedSplitsInfo(); int getUnacknowledgedPartitionedSplitCount(); diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index b9e706c79c0f..bfae1a6636f5 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -15,7 +15,6 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.concurrent.SetThreadName; -import io.airlift.log.Logger; import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.trino.Session; @@ -25,8 +24,6 @@ import io.trino.cost.StatsCalculator; import io.trino.execution.QueryPreparer.PreparedQuery; import io.trino.execution.StateMachine.StateChangeListener; -import io.trino.execution.buffer.OutputBuffers; -import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.execution.scheduler.ExecutionPolicy; import io.trino.execution.scheduler.NodeScheduler; import io.trino.execution.scheduler.SplitSchedulerStats; @@ -43,24 +40,22 @@ import io.trino.spi.QueryId; import io.trino.spi.TrinoException; import io.trino.spi.type.TypeOperators; -import io.trino.split.SplitManager; -import io.trino.split.SplitSource; import io.trino.sql.analyzer.Analysis; import io.trino.sql.analyzer.Analyzer; import io.trino.sql.analyzer.AnalyzerFactory; -import io.trino.sql.planner.DistributedExecutionPlanner; import io.trino.sql.planner.InputExtractor; import io.trino.sql.planner.LogicalPlanner; import io.trino.sql.planner.NodePartitioningManager; -import io.trino.sql.planner.PartitioningHandle; import 
io.trino.sql.planner.Plan; +import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.PlanFragmenter; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.PlanOptimizersFactory; -import io.trino.sql.planner.StageExecutionPlan; +import io.trino.sql.planner.SplitSourceFactory; import io.trino.sql.planner.SubPlan; import io.trino.sql.planner.TypeAnalyzer; import io.trino.sql.planner.optimizations.PlanOptimizer; +import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.tree.ExplainAnalyze; import io.trino.sql.tree.Query; import io.trino.sql.tree.Statement; @@ -86,9 +81,6 @@ import static io.trino.SystemSessionProperties.isEnableDynamicFiltering; import static io.trino.execution.QueryState.FAILED; import static io.trino.execution.QueryState.PLANNING; -import static io.trino.execution.buffer.OutputBuffers.BROADCAST_PARTITION_ID; -import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; -import static io.trino.execution.scheduler.SqlQueryScheduler.createSqlQueryScheduler; import static io.trino.server.DynamicFilterService.DynamicFiltersStats; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.sql.ParameterUtils.parameterExtractor; @@ -100,15 +92,11 @@ public class SqlQueryExecution implements QueryExecution { - private static final Logger log = Logger.get(SqlQueryExecution.class); - - private static final OutputBufferId OUTPUT_BUFFER_ID = new OutputBufferId(0); - private final QueryStateMachine stateMachine; private final Slug slug; private final Metadata metadata; private final TypeOperators typeOperators; - private final SplitManager splitManager; + private final SplitSourceFactory splitSourceFactory; private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; private final List planOptimizers; @@ -130,6 +118,7 @@ public class SqlQueryExecution private final DynamicFilterService dynamicFilterService; private final 
TableExecuteContextManager tableExecuteContextManager; private final TypeAnalyzer typeAnalyzer; + private final TaskManager coordinatorTaskManager; private SqlQueryExecution( PreparedQuery preparedQuery, @@ -138,7 +127,7 @@ private SqlQueryExecution( Metadata metadata, TypeOperators typeOperators, AnalyzerFactory analyzerFactory, - SplitManager splitManager, + SplitSourceFactory splitSourceFactory, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, List planOptimizers, @@ -156,13 +145,14 @@ private SqlQueryExecution( DynamicFilterService dynamicFilterService, WarningCollector warningCollector, TableExecuteContextManager tableExecuteContextManager, - TypeAnalyzer typeAnalyzer) + TypeAnalyzer typeAnalyzer, + TaskManager coordinatorTaskManager) { try (SetThreadName ignored = new SetThreadName("Query-%s", stateMachine.getQueryId())) { this.slug = requireNonNull(slug, "slug is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.typeOperators = requireNonNull(typeOperators, "typeOperators is null"); - this.splitManager = requireNonNull(splitManager, "splitManager is null"); + this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null"); @@ -212,6 +202,7 @@ private SqlQueryExecution( this.remoteTaskFactory = new MemoryTrackingRemoteTaskFactory(requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"), stateMachine); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + this.coordinatorTaskManager = requireNonNull(coordinatorTaskManager, "coordinatorTaskManager is null"); } } @@ -495,49 +486,37 @@ private PlanRoot doPlanQuery() private void planDistribution(PlanRoot plan) { - // plan the execution on the 
active nodes - DistributedExecutionPlanner distributedPlanner = new DistributedExecutionPlanner(splitManager, metadata, dynamicFilterService, typeAnalyzer); - StageExecutionPlan outputStageExecutionPlan = distributedPlanner.plan(plan.getRoot(), stateMachine.getSession()); - - // ensure split sources are closed - stateMachine.addStateChangeListener(state -> { - if (state.isDone()) { - closeSplitSources(outputStageExecutionPlan); - } - }); - // if query was canceled, skip creating scheduler if (stateMachine.isDone()) { return; } // record output field - stateMachine.setColumns(outputStageExecutionPlan.getFieldNames(), outputStageExecutionPlan.getFragment().getTypes()); - - PartitioningHandle partitioningHandle = plan.getRoot().getFragment().getPartitioningScheme().getPartitioning().getHandle(); - OutputBuffers rootOutputBuffers = createInitialEmptyOutputBuffers(partitioningHandle) - .withBuffer(OUTPUT_BUFFER_ID, BROADCAST_PARTITION_ID) - .withNoMoreBufferIds(); + PlanFragment rootFragment = plan.getRoot().getFragment(); + stateMachine.setColumns( + ((OutputNode) rootFragment.getRoot()).getColumnNames(), + rootFragment.getTypes()); // build the stage execution objects (this doesn't schedule execution) - SqlQueryScheduler scheduler = createSqlQueryScheduler( + SqlQueryScheduler scheduler = new SqlQueryScheduler( stateMachine, - outputStageExecutionPlan, + plan.getRoot(), nodePartitioningManager, nodeScheduler, remoteTaskFactory, - stateMachine.getSession(), plan.isSummarizeTaskInfos(), scheduleSplitBatchSize, queryExecutor, schedulerExecutor, failureDetector, - rootOutputBuffers, nodeTaskMap, executionPolicy, schedulerStats, dynamicFilterService, - tableExecuteContextManager); + tableExecuteContextManager, + metadata, + splitSourceFactory, + coordinatorTaskManager); queryScheduler.set(scheduler); @@ -549,22 +528,6 @@ private void planDistribution(PlanRoot plan) } } - private static void closeSplitSources(StageExecutionPlan plan) - { - for (SplitSource source : 
plan.getSplitSources().values()) { - try { - source.close(); - } - catch (Throwable t) { - log.warn(t, "Error closing split source"); - } - } - - for (StageExecutionPlan stage : plan.getSubStages()) { - closeSplitSources(stage); - } - } - @Override public void cancelQuery() { @@ -604,6 +567,12 @@ public void addOutputInfoListener(Consumer listener) stateMachine.addOutputInfoListener(listener); } + @Override + public void outputTaskFailed(TaskId taskId, Throwable failure) + { + stateMachine.outputTaskFailed(taskId, failure); + } + @Override public ListenableFuture getStateChange(QueryState currentState) { @@ -717,7 +686,7 @@ public static class SqlQueryExecutionFactory private final Metadata metadata; private final TypeOperators typeOperators; private final AnalyzerFactory analyzerFactory; - private final SplitManager splitManager; + private final SplitSourceFactory splitSourceFactory; private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; private final List planOptimizers; @@ -733,6 +702,7 @@ public static class SqlQueryExecutionFactory private final DynamicFilterService dynamicFilterService; private final TableExecuteContextManager tableExecuteContextManager; private final TypeAnalyzer typeAnalyzer; + private final TaskManager coordinatorTaskManager; @Inject SqlQueryExecutionFactory( @@ -740,7 +710,7 @@ public static class SqlQueryExecutionFactory Metadata metadata, TypeOperators typeOperators, AnalyzerFactory analyzerFactory, - SplitManager splitManager, + SplitSourceFactory splitSourceFactory, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, PlanOptimizersFactory planOptimizersFactory, @@ -756,7 +726,8 @@ public static class SqlQueryExecutionFactory CostCalculator costCalculator, DynamicFilterService dynamicFilterService, TableExecuteContextManager tableExecuteContextManager, - TypeAnalyzer typeAnalyzer) + TypeAnalyzer typeAnalyzer, + TaskManager coordinatorTaskManager) { 
requireNonNull(config, "config is null"); this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); @@ -764,7 +735,7 @@ public static class SqlQueryExecutionFactory this.metadata = requireNonNull(metadata, "metadata is null"); this.typeOperators = requireNonNull(typeOperators, "typeOperators is null"); this.analyzerFactory = requireNonNull(analyzerFactory, "analyzerFactory is null"); - this.splitManager = requireNonNull(splitManager, "splitManager is null"); + this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null"); @@ -780,6 +751,7 @@ public static class SqlQueryExecutionFactory this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + this.coordinatorTaskManager = requireNonNull(coordinatorTaskManager, "coordinatorTaskManager is null"); } @Override @@ -800,7 +772,7 @@ public QueryExecution createQueryExecution( metadata, typeOperators, analyzerFactory, - splitManager, + splitSourceFactory, nodePartitioningManager, nodeScheduler, planOptimizers, @@ -818,7 +790,8 @@ public QueryExecution createQueryExecution( dynamicFilterService, warningCollector, tableExecuteContextManager, - typeAnalyzer); + typeAnalyzer, + coordinatorTaskManager); } } } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java index 1715cd6825ac..695f1a924b53 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java +++ 
b/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java @@ -155,6 +155,12 @@ public void addOutputInfoListener(QueryId queryId, Consumer lis queryTracker.getQuery(queryId).addOutputInfoListener(listener); } + @Override + public void outputTaskFailed(TaskId taskId, Throwable failure) + { + queryTracker.getQuery(taskId.getQueryId()).outputTaskFailed(taskId, failure); + } + @Override public void addStateChangeListener(QueryId queryId, StateChangeListener listener) { diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java new file mode 100644 index 000000000000..d81e49861aa0 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java @@ -0,0 +1,304 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution; + +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multimap; +import io.airlift.units.Duration; +import io.trino.Session; +import io.trino.execution.StateMachine.StateChangeListener; +import io.trino.execution.buffer.OutputBuffers; +import io.trino.execution.scheduler.SplitSchedulerStats; +import io.trino.metadata.InternalNode; +import io.trino.metadata.Split; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.plan.DynamicFilterId; +import io.trino.sql.planner.plan.PlanNodeId; + +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.SystemSessionProperties.isEnableCoordinatorDynamicFiltersDistribution; +import static io.trino.server.DynamicFilterService.getOutboundDynamicFilters; +import static java.util.Objects.requireNonNull; + +@ThreadSafe +public final class SqlStage +{ + private final Session session; + private final StageStateMachine stateMachine; + private final RemoteTaskFactory remoteTaskFactory; + private final NodeTaskMap nodeTaskMap; + private final boolean summarizeTaskInfo; + + private final Set outboundDynamicFilterIds; + + private final Map tasks = new ConcurrentHashMap<>(); + @GuardedBy("this") + private final Set allTasks = new HashSet<>(); + @GuardedBy("this") + private final Set finishedTasks = new HashSet<>(); + @GuardedBy("this") + private final Set tasksWithFinalInfo = new HashSet<>(); + + public static SqlStage createSqlStage( + StageId stageId, + PlanFragment fragment, + Map tables, + RemoteTaskFactory 
remoteTaskFactory, + Session session, + boolean summarizeTaskInfo, + NodeTaskMap nodeTaskMap, + Executor executor, + SplitSchedulerStats schedulerStats) + { + requireNonNull(stageId, "stageId is null"); + requireNonNull(fragment, "fragment is null"); + checkArgument(fragment.getPartitioningScheme().getBucketToPartition().isEmpty(), "bucket to partition is not expected to be set at this point"); + requireNonNull(tables, "tables is null"); + requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); + requireNonNull(session, "session is null"); + requireNonNull(nodeTaskMap, "nodeTaskMap is null"); + requireNonNull(executor, "executor is null"); + requireNonNull(schedulerStats, "schedulerStats is null"); + + SqlStage sqlStage = new SqlStage( + session, + new StageStateMachine(stageId, fragment, tables, executor, schedulerStats), + remoteTaskFactory, + nodeTaskMap, + summarizeTaskInfo); + sqlStage.initialize(); + return sqlStage; + } + + private SqlStage( + Session session, + StageStateMachine stateMachine, + RemoteTaskFactory remoteTaskFactory, + NodeTaskMap nodeTaskMap, + boolean summarizeTaskInfo) + { + this.session = requireNonNull(session, "session is null"); + this.stateMachine = stateMachine; + this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); + this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null"); + this.summarizeTaskInfo = summarizeTaskInfo; + + if (isEnableCoordinatorDynamicFiltersDistribution(session)) { + this.outboundDynamicFilterIds = getOutboundDynamicFilters(stateMachine.getFragment()); + } + else { + this.outboundDynamicFilterIds = ImmutableSet.of(); + } + } + + // this is a separate method to ensure that the `this` reference is not leaked during construction + private void initialize() + { + stateMachine.addStateChangeListener(newState -> checkAllTaskFinal()); + } + + public StageId getStageId() + { + return stateMachine.getStageId(); + } + + public synchronized void finish() + { + 
stateMachine.transitionToFinished(); + tasks.values().forEach(RemoteTask::cancel); + } + + public synchronized void abort() + { + stateMachine.transitionToAborted(); + tasks.values().forEach(RemoteTask::abort); + } + + public synchronized void fail(Throwable throwable) + { + requireNonNull(throwable, "throwable is null"); + stateMachine.transitionToFailed(throwable); + tasks.values().forEach(RemoteTask::abort); + } + + /** + * Add a listener for the final stage info. This notification is guaranteed to be fired only once. + * Listener is always notified asynchronously using a dedicated notification thread pool so, care should + * be taken to avoid leaking {@code this} when adding a listener in a constructor. Additionally, it is + * possible notifications are observed out of order due to the asynchronous execution. + */ + public void addFinalStageInfoListener(StateChangeListener stateChangeListener) + { + stateMachine.addFinalStageInfoListener(stateChangeListener); + } + + public PlanFragment getFragment() + { + return stateMachine.getFragment(); + } + + public long getUserMemoryReservation() + { + return stateMachine.getUserMemoryReservation(); + } + + public long getTotalMemoryReservation() + { + return stateMachine.getTotalMemoryReservation(); + } + + public Duration getTotalCpuTime() + { + long millis = tasks.values().stream() + .mapToLong(task -> task.getTaskInfo().getStats().getTotalCpuTime().toMillis()) + .sum(); + return new Duration(millis, TimeUnit.MILLISECONDS); + } + + public BasicStageStats getBasicStageStats() + { + return stateMachine.getBasicStageStats(this::getAllTaskInfo); + } + + public StageInfo getStageInfo() + { + return stateMachine.getStageInfo(this::getAllTaskInfo); + } + + private Iterable getAllTaskInfo() + { + return tasks.values().stream() + .map(RemoteTask::getTaskInfo) + .collect(toImmutableList()); + } + + public synchronized Optional createTask( + InternalNode node, + int partition, + int attempt, + Optional bucketToPartition, + 
OutputBuffers outputBuffers, + Multimap splits, + Multimap noMoreSplitsForLifespan, + Set noMoreSplits) + { + if (stateMachine.getState().isDone()) { + return Optional.empty(); + } + TaskId taskId = new TaskId(stateMachine.getStageId(), partition, attempt); + checkArgument(!tasks.containsKey(taskId), "A task with id %s already exists", taskId); + + stateMachine.transitionToScheduling(); + + RemoteTask task = remoteTaskFactory.createRemoteTask( + session, + taskId, + node, + stateMachine.getFragment().withBucketToPartition(bucketToPartition), + splits, + outputBuffers, + nodeTaskMap.createPartitionedSplitCountTracker(node, taskId), + outboundDynamicFilterIds, + summarizeTaskInfo); + + noMoreSplitsForLifespan.forEach(task::noMoreSplits); + noMoreSplits.forEach(task::noMoreSplits); + + tasks.put(taskId, task); + allTasks.add(taskId); + nodeTaskMap.addTask(node, task); + + task.addStateChangeListener(this::updateTaskStatus); + task.addStateChangeListener(new MemoryUsageListener()); + task.addFinalTaskInfoListener(this::updateFinalTaskInfo); + + return Optional.of(task); + } + + public void recordGetSplitTime(long start) + { + stateMachine.recordGetSplitTime(start); + } + + private synchronized void updateTaskStatus(TaskStatus status) + { + if (status.getState().isDone()) { + finishedTasks.add(status.getTaskId()); + } + if (!finishedTasks.containsAll(allTasks)) { + stateMachine.transitionToRunning(); + } + else { + stateMachine.transitionToPending(); + } + } + + private synchronized void updateFinalTaskInfo(TaskInfo finalTaskInfo) + { + tasksWithFinalInfo.add(finalTaskInfo.getTaskStatus().getTaskId()); + checkAllTaskFinal(); + } + + private synchronized void checkAllTaskFinal() + { + if (stateMachine.getState().isDone() && tasksWithFinalInfo.containsAll(tasks.keySet())) { + List finalTaskInfos = tasks.values().stream() + .map(RemoteTask::getTaskInfo) + .collect(toImmutableList()); + stateMachine.setAllTasksFinal(finalTaskInfos); + } + } + + @Override + public String 
toString() + { + return stateMachine.toString(); + } + + private class MemoryUsageListener + implements StateChangeListener + { + private long previousUserMemory; + private long previousSystemMemory; + private long previousRevocableMemory; + + @Override + public synchronized void stateChanged(TaskStatus taskStatus) + { + long currentUserMemory = taskStatus.getMemoryReservation().toBytes(); + long currentSystemMemory = taskStatus.getSystemMemoryReservation().toBytes(); + long currentRevocableMemory = taskStatus.getRevocableMemoryReservation().toBytes(); + long deltaUserMemoryInBytes = currentUserMemory - previousUserMemory; + long deltaRevocableMemoryInBytes = currentRevocableMemory - previousRevocableMemory; + long deltaTotalMemoryInBytes = (currentUserMemory + currentSystemMemory + currentRevocableMemory) - (previousUserMemory + previousSystemMemory + previousRevocableMemory); + previousUserMemory = currentUserMemory; + previousSystemMemory = currentSystemMemory; + previousRevocableMemory = currentRevocableMemory; + stateMachine.updateMemoryUsage(deltaUserMemoryInBytes, deltaRevocableMemoryInBytes, deltaTotalMemoryInBytes); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlStageExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlStageExecution.java deleted file mode 100644 index 8b38208a1265..000000000000 --- a/core/trino-main/src/main/java/io/trino/execution/SqlStageExecution.java +++ /dev/null @@ -1,684 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution; - -import com.google.common.collect.HashMultimap; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Multimap; -import com.google.common.collect.Sets; -import io.airlift.units.Duration; -import io.trino.Session; -import io.trino.execution.StateMachine.StateChangeListener; -import io.trino.execution.buffer.OutputBuffers; -import io.trino.execution.scheduler.SplitSchedulerStats; -import io.trino.failuredetector.FailureDetector; -import io.trino.metadata.InternalNode; -import io.trino.metadata.Split; -import io.trino.server.DynamicFilterService; -import io.trino.spi.TrinoException; -import io.trino.split.RemoteSplit; -import io.trino.sql.planner.PlanFragment; -import io.trino.sql.planner.plan.DynamicFilterId; -import io.trino.sql.planner.plan.PlanFragmentId; -import io.trino.sql.planner.plan.PlanNodeId; -import io.trino.sql.planner.plan.RemoteSourceNode; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; - -import java.net.URI; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.Executor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.ImmutableList.toImmutableList; 
-import static com.google.common.collect.Sets.newConcurrentHashSet; -import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; -import static io.trino.SystemSessionProperties.isEnableCoordinatorDynamicFiltersDistribution; -import static io.trino.failuredetector.FailureDetector.State.GONE; -import static io.trino.operator.ExchangeOperator.REMOTE_CONNECTOR_ID; -import static io.trino.server.DynamicFilterService.getOutboundDynamicFilters; -import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; -import static io.trino.spi.StandardErrorCode.REMOTE_HOST_GONE; -import static java.util.Objects.requireNonNull; - -@ThreadSafe -public final class SqlStageExecution -{ - private final StageStateMachine stateMachine; - private final RemoteTaskFactory remoteTaskFactory; - private final NodeTaskMap nodeTaskMap; - private final boolean summarizeTaskInfo; - private final Executor executor; - private final FailureDetector failureDetector; - private final DynamicFilterService dynamicFilterService; - - private final Map exchangeSources; - - private final Map> tasks = new ConcurrentHashMap<>(); - - @GuardedBy("this") - private final AtomicInteger nextTaskId = new AtomicInteger(); - @GuardedBy("this") - private final Set allTasks = newConcurrentHashSet(); - @GuardedBy("this") - private final Set finishedTasks = newConcurrentHashSet(); - @GuardedBy("this") - private final Set flushingTasks = newConcurrentHashSet(); - @GuardedBy("this") - private final Set tasksWithFinalInfo = newConcurrentHashSet(); - @GuardedBy("this") - private final AtomicBoolean splitsScheduled = new AtomicBoolean(); - - @GuardedBy("this") - private final Multimap sourceTasks = HashMultimap.create(); - @GuardedBy("this") - private final Set completeSources = newConcurrentHashSet(); - @GuardedBy("this") - private final Set completeSourceFragments = newConcurrentHashSet(); - - private final AtomicReference outputBuffers = new AtomicReference<>(); - - private final ListenerManager> 
completedLifespansChangeListeners = new ListenerManager<>(); - - private final Set outboundDynamicFilterIds; - - public static SqlStageExecution createSqlStageExecution( - StageId stageId, - PlanFragment fragment, - Map tables, - RemoteTaskFactory remoteTaskFactory, - Session session, - boolean summarizeTaskInfo, - NodeTaskMap nodeTaskMap, - Executor executor, - FailureDetector failureDetector, - DynamicFilterService dynamicFilterService, - SplitSchedulerStats schedulerStats) - { - requireNonNull(stageId, "stageId is null"); - requireNonNull(fragment, "fragment is null"); - requireNonNull(tables, "tables is null"); - requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); - requireNonNull(session, "session is null"); - requireNonNull(nodeTaskMap, "nodeTaskMap is null"); - requireNonNull(executor, "executor is null"); - requireNonNull(failureDetector, "failureDetector is null"); - requireNonNull(dynamicFilterService, "dynamicFilterService is null"); - requireNonNull(schedulerStats, "schedulerStats is null"); - - SqlStageExecution sqlStageExecution = new SqlStageExecution( - new StageStateMachine(stageId, session, fragment, tables, executor, schedulerStats), - remoteTaskFactory, - nodeTaskMap, - summarizeTaskInfo, - executor, - failureDetector, - dynamicFilterService); - sqlStageExecution.initialize(); - return sqlStageExecution; - } - - private SqlStageExecution( - StageStateMachine stateMachine, - RemoteTaskFactory remoteTaskFactory, - NodeTaskMap nodeTaskMap, - boolean summarizeTaskInfo, - Executor executor, - FailureDetector failureDetector, - DynamicFilterService dynamicFilterService) - { - this.stateMachine = stateMachine; - this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); - this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null"); - this.summarizeTaskInfo = summarizeTaskInfo; - this.executor = requireNonNull(executor, "executor is null"); - this.failureDetector = requireNonNull(failureDetector, 
"failureDetector is null"); - this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); - - ImmutableMap.Builder fragmentToExchangeSource = ImmutableMap.builder(); - for (RemoteSourceNode remoteSourceNode : stateMachine.getFragment().getRemoteSourceNodes()) { - for (PlanFragmentId planFragmentId : remoteSourceNode.getSourceFragmentIds()) { - fragmentToExchangeSource.put(planFragmentId, remoteSourceNode); - } - } - this.exchangeSources = fragmentToExchangeSource.build(); - if (isEnableCoordinatorDynamicFiltersDistribution(stateMachine.getSession())) { - this.outboundDynamicFilterIds = getOutboundDynamicFilters(stateMachine.getFragment()); - } - else { - this.outboundDynamicFilterIds = ImmutableSet.of(); - } - } - - // this is a separate method to ensure that the `this` reference is not leaked during construction - private void initialize() - { - stateMachine.addStateChangeListener(newState -> checkAllTaskFinal()); - stateMachine.addStateChangeListener(newState -> { - if (!newState.canScheduleMoreTasks()) { - dynamicFilterService.stageCannotScheduleMoreTasks(stateMachine.getStageId(), getAllTasks().size()); - } - }); - } - - public StageId getStageId() - { - return stateMachine.getStageId(); - } - - public StageState getState() - { - return stateMachine.getState(); - } - - /** - * Listener is always notified asynchronously using a dedicated notification thread pool so, care should - * be taken to avoid leaking {@code this} when adding a listener in a constructor. - */ - public void addStateChangeListener(StateChangeListener stateChangeListener) - { - stateMachine.addStateChangeListener(stateChangeListener); - } - - /** - * Add a listener for the final stage info. This notification is guaranteed to be fired only once. - * Listener is always notified asynchronously using a dedicated notification thread pool so, care should - * be taken to avoid leaking {@code this} when adding a listener in a constructor. 
Additionally, it is - * possible notifications are observed out of order due to the asynchronous execution. - */ - public void addFinalStageInfoListener(StateChangeListener stateChangeListener) - { - stateMachine.addFinalStageInfoListener(stateChangeListener); - } - - public void addCompletedDriverGroupsChangedListener(Consumer> newlyCompletedDriverGroupConsumer) - { - completedLifespansChangeListeners.addListener(newlyCompletedDriverGroupConsumer); - } - - public PlanFragment getFragment() - { - return stateMachine.getFragment(); - } - - public OutputBuffers getOutputBuffers() - { - return outputBuffers.get(); - } - - public void beginScheduling() - { - stateMachine.transitionToScheduling(); - } - - public synchronized void transitionToSchedulingSplits() - { - stateMachine.transitionToSchedulingSplits(); - } - - public synchronized void schedulingComplete() - { - if (!stateMachine.transitionToScheduled()) { - return; - } - - if (isFlushing()) { - stateMachine.transitionToFlushing(); - } - if (finishedTasks.containsAll(allTasks)) { - stateMachine.transitionToFinished(); - } - - for (PlanNodeId partitionedSource : stateMachine.getFragment().getPartitionedSources()) { - schedulingComplete(partitionedSource); - } - } - - public synchronized void schedulingComplete(PlanNodeId partitionedSource) - { - for (RemoteTask task : getAllTasks()) { - task.noMoreSplits(partitionedSource); - } - completeSources.add(partitionedSource); - } - - public synchronized void cancel() - { - stateMachine.transitionToCanceled(); - getAllTasks().forEach(RemoteTask::cancel); - } - - public synchronized void abort() - { - stateMachine.transitionToAborted(); - getAllTasks().forEach(RemoteTask::abort); - } - - public long getUserMemoryReservation() - { - return stateMachine.getUserMemoryReservation(); - } - - public long getTotalMemoryReservation() - { - return stateMachine.getTotalMemoryReservation(); - } - - public Duration getTotalCpuTime() - { - long millis = getAllTasks().stream() - 
.mapToLong(task -> task.getTaskInfo().getStats().getTotalCpuTime().toMillis()) - .sum(); - return new Duration(millis, TimeUnit.MILLISECONDS); - } - - public BasicStageStats getBasicStageStats() - { - return stateMachine.getBasicStageStats(this::getAllTaskInfo); - } - - public StageInfo getStageInfo() - { - return stateMachine.getStageInfo(this::getAllTaskInfo); - } - - private Iterable getAllTaskInfo() - { - return getAllTasks().stream() - .map(RemoteTask::getTaskInfo) - .collect(toImmutableList()); - } - - public synchronized void addExchangeLocations(PlanFragmentId fragmentId, Set sourceTasks, boolean noMoreExchangeLocations) - { - requireNonNull(fragmentId, "fragmentId is null"); - requireNonNull(sourceTasks, "sourceTasks is null"); - - RemoteSourceNode remoteSource = exchangeSources.get(fragmentId); - checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s", fragmentId, exchangeSources.keySet()); - - this.sourceTasks.putAll(remoteSource.getId(), sourceTasks); - - for (RemoteTask task : getAllTasks()) { - ImmutableMultimap.Builder newSplits = ImmutableMultimap.builder(); - for (RemoteTask sourceTask : sourceTasks) { - URI exchangeLocation = sourceTask.getTaskStatus().getSelf(); - newSplits.put(remoteSource.getId(), createRemoteSplitFor(task.getTaskId(), exchangeLocation)); - } - task.addSplits(newSplits.build()); - } - - if (noMoreExchangeLocations) { - completeSourceFragments.add(fragmentId); - - // is the source now complete? 
- if (completeSourceFragments.containsAll(remoteSource.getSourceFragmentIds())) { - completeSources.add(remoteSource.getId()); - for (RemoteTask task : getAllTasks()) { - task.noMoreSplits(remoteSource.getId()); - } - } - } - } - - public synchronized void setOutputBuffers(OutputBuffers outputBuffers) - { - requireNonNull(outputBuffers, "outputBuffers is null"); - - while (true) { - OutputBuffers currentOutputBuffers = this.outputBuffers.get(); - if (currentOutputBuffers != null) { - if (outputBuffers.getVersion() <= currentOutputBuffers.getVersion()) { - return; - } - currentOutputBuffers.checkValidTransition(outputBuffers); - } - - if (this.outputBuffers.compareAndSet(currentOutputBuffers, outputBuffers)) { - for (RemoteTask task : getAllTasks()) { - task.setOutputBuffers(outputBuffers); - } - return; - } - } - } - - // do not synchronize - // this is used for query info building which should be independent of scheduling work - public boolean hasTasks() - { - return !tasks.isEmpty(); - } - - // do not synchronize - // this is used for query info building which should be independent of scheduling work - public List getAllTasks() - { - return tasks.values().stream() - .flatMap(Set::stream) - .collect(toImmutableList()); - } - - public synchronized Optional scheduleTask(InternalNode node, int partition) - { - requireNonNull(node, "node is null"); - - if (stateMachine.getState().isDone()) { - return Optional.empty(); - } - checkState(!splitsScheduled.get(), "scheduleTask cannot be called once splits have been scheduled"); - return Optional.of(scheduleTask(node, new TaskId(stateMachine.getStageId(), partition), ImmutableMultimap.of())); - } - - public synchronized Set scheduleSplits(InternalNode node, Multimap splits, Multimap noMoreSplitsNotification) - { - requireNonNull(node, "node is null"); - requireNonNull(splits, "splits is null"); - - if (stateMachine.getState().isDone()) { - return ImmutableSet.of(); - } - splitsScheduled.set(true); - - 
checkArgument(stateMachine.getFragment().getPartitionedSources().containsAll(splits.keySet()), "Invalid splits"); - - ImmutableSet.Builder newTasks = ImmutableSet.builder(); - Collection tasks = this.tasks.get(node); - RemoteTask task; - if (tasks == null) { - // The output buffer depends on the task id starting from 0 and being sequential, since each - // task is assigned a private buffer based on task id. - TaskId taskId = new TaskId(stateMachine.getStageId(), nextTaskId.getAndIncrement()); - task = scheduleTask(node, taskId, splits); - newTasks.add(task); - } - else { - task = tasks.iterator().next(); - task.addSplits(splits); - } - if (noMoreSplitsNotification.size() > 1) { - // The assumption that `noMoreSplitsNotification.size() <= 1` currently holds. - // If this assumption no longer holds, we should consider calling task.noMoreSplits with multiple entries in one shot. - // These kind of methods can be expensive since they are grabbing locks and/or sending HTTP requests on change. 
- throw new UnsupportedOperationException("This assumption no longer holds: noMoreSplitsNotification.size() < 1"); - } - for (Entry entry : noMoreSplitsNotification.entries()) { - task.noMoreSplits(entry.getKey(), entry.getValue()); - } - return newTasks.build(); - } - - private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, Multimap sourceSplits) - { - checkArgument(!allTasks.contains(taskId), "A task with id %s already exists", taskId); - - ImmutableMultimap.Builder initialSplits = ImmutableMultimap.builder(); - initialSplits.putAll(sourceSplits); - - sourceTasks.forEach((planNodeId, task) -> { - TaskStatus status = task.getTaskStatus(); - if (status.getState() != TaskState.FINISHED) { - initialSplits.put(planNodeId, createRemoteSplitFor(taskId, status.getSelf())); - } - }); - - OutputBuffers outputBuffers = this.outputBuffers.get(); - checkState(outputBuffers != null, "Initial output buffers must be set before a task can be scheduled"); - - RemoteTask task = remoteTaskFactory.createRemoteTask( - stateMachine.getSession(), - taskId, - node, - stateMachine.getFragment(), - initialSplits.build(), - outputBuffers, - nodeTaskMap.createPartitionedSplitCountTracker(node, taskId), - outboundDynamicFilterIds, - summarizeTaskInfo); - - completeSources.forEach(task::noMoreSplits); - - allTasks.add(taskId); - tasks.computeIfAbsent(node, key -> newConcurrentHashSet()).add(task); - nodeTaskMap.addTask(node, task); - - task.addStateChangeListener(new StageTaskListener()); - task.addFinalTaskInfoListener(this::updateFinalTaskInfo); - - if (!stateMachine.getState().isDone()) { - task.start(); - } - else { - // stage finished while we were scheduling this task - task.abort(); - } - - return task; - } - - public Set getScheduledNodes() - { - return ImmutableSet.copyOf(tasks.keySet()); - } - - public void recordGetSplitTime(long start) - { - stateMachine.recordGetSplitTime(start); - } - - private static Split createRemoteSplitFor(TaskId taskId, URI 
taskLocation) - { - // Fetch the results from the buffer assigned to the task based on id - URI splitLocation = uriBuilderFrom(taskLocation).appendPath("results").appendPath(String.valueOf(taskId.getId())).build(); - return new Split(REMOTE_CONNECTOR_ID, new RemoteSplit(splitLocation), Lifespan.taskWide()); - } - - private synchronized void updateTaskStatus(TaskStatus taskStatus) - { - try { - StageState stageState = getState(); - if (stageState.isDone()) { - return; - } - - TaskState taskState = taskStatus.getState(); - - switch (taskState) { - case FAILED: - RuntimeException failure = taskStatus.getFailures().stream() - .findFirst() - .map(this::rewriteTransportFailure) - .map(ExecutionFailureInfo::toException) - .orElse(new TrinoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason")); - stateMachine.transitionToFailed(failure); - break; - case ABORTED: - // A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED) - stateMachine.transitionToFailed(new TrinoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageState)); - break; - case FLUSHING: - flushingTasks.add(taskStatus.getTaskId()); - break; - case FINISHED: - finishedTasks.add(taskStatus.getTaskId()); - flushingTasks.remove(taskStatus.getTaskId()); - break; - default: - } - - if (stageState == StageState.SCHEDULED || stageState == StageState.RUNNING || stageState == StageState.FLUSHING) { - if (taskState == TaskState.RUNNING) { - stateMachine.transitionToRunning(); - } - if (isFlushing()) { - stateMachine.transitionToFlushing(); - } - if (finishedTasks.containsAll(allTasks)) { - stateMachine.transitionToFinished(); - } - } - } - finally { - // after updating state, check if all tasks have final status information - checkAllTaskFinal(); - } - } - - private synchronized boolean isFlushing() - { - // to transition to flushing, there must be at least one flushing task, and all others must be flushing or finished. 
- return !flushingTasks.isEmpty() - && allTasks.stream().allMatch(taskId -> finishedTasks.contains(taskId) || flushingTasks.contains(taskId)); - } - - private synchronized void updateFinalTaskInfo(TaskInfo finalTaskInfo) - { - tasksWithFinalInfo.add(finalTaskInfo.getTaskStatus().getTaskId()); - checkAllTaskFinal(); - } - - private synchronized void checkAllTaskFinal() - { - if (stateMachine.getState().isDone() && tasksWithFinalInfo.containsAll(allTasks)) { - List finalTaskInfos = getAllTasks().stream() - .map(RemoteTask::getTaskInfo) - .collect(toImmutableList()); - stateMachine.setAllTasksFinal(finalTaskInfos); - } - } - - public List getTaskStatuses() - { - return getAllTasks().stream() - .map(RemoteTask::getTaskStatus) - .collect(toImmutableList()); - } - - public boolean isAnyTaskBlocked() - { - return getTaskStatuses().stream().anyMatch(TaskStatus::isOutputBufferOverutilized); - } - - private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo) - { - if (executionFailureInfo.getRemoteHost() == null || failureDetector.getState(executionFailureInfo.getRemoteHost()) != GONE) { - return executionFailureInfo; - } - - return new ExecutionFailureInfo( - executionFailureInfo.getType(), - executionFailureInfo.getMessage(), - executionFailureInfo.getCause(), - executionFailureInfo.getSuppressed(), - executionFailureInfo.getStack(), - executionFailureInfo.getErrorLocation(), - REMOTE_HOST_GONE.toErrorCode(), - executionFailureInfo.getRemoteHost()); - } - - @Override - public String toString() - { - return stateMachine.toString(); - } - - private class StageTaskListener - implements StateChangeListener - { - private long previousUserMemory; - private long previousSystemMemory; - private long previousRevocableMemory; - private final Set completedDriverGroups = new HashSet<>(); - - @Override - public void stateChanged(TaskStatus taskStatus) - { - try { - updateMemoryUsage(taskStatus); - updateCompletedDriverGroups(taskStatus); - } - 
finally { - updateTaskStatus(taskStatus); - } - } - - private synchronized void updateMemoryUsage(TaskStatus taskStatus) - { - long currentUserMemory = taskStatus.getMemoryReservation().toBytes(); - long currentSystemMemory = taskStatus.getSystemMemoryReservation().toBytes(); - long currentRevocableMemory = taskStatus.getRevocableMemoryReservation().toBytes(); - long deltaUserMemoryInBytes = currentUserMemory - previousUserMemory; - long deltaRevocableMemoryInBytes = currentRevocableMemory - previousRevocableMemory; - long deltaTotalMemoryInBytes = (currentUserMemory + currentSystemMemory + currentRevocableMemory) - (previousUserMemory + previousSystemMemory + previousRevocableMemory); - previousUserMemory = currentUserMemory; - previousSystemMemory = currentSystemMemory; - previousRevocableMemory = currentRevocableMemory; - stateMachine.updateMemoryUsage(deltaUserMemoryInBytes, deltaRevocableMemoryInBytes, deltaTotalMemoryInBytes); - } - - private synchronized void updateCompletedDriverGroups(TaskStatus taskStatus) - { - // Sets.difference returns a view. - // Once we add the difference into `completedDriverGroups`, the view will be empty. - // `completedLifespansChangeListeners.invoke` happens asynchronously. - // As a result, calling the listeners before updating `completedDriverGroups` doesn't make a difference. - // That's why a copy must be made here. - Set newlyCompletedDriverGroups = ImmutableSet.copyOf(Sets.difference(taskStatus.getCompletedDriverGroups(), this.completedDriverGroups)); - if (newlyCompletedDriverGroups.isEmpty()) { - return; - } - completedLifespansChangeListeners.invoke(newlyCompletedDriverGroups, executor); - // newlyCompletedDriverGroups is a view. - // Making changes to completedDriverGroups will change newlyCompletedDriverGroups. 
- completedDriverGroups.addAll(newlyCompletedDriverGroups); - } - } - - private static class ListenerManager - { - private final List> listeners = new ArrayList<>(); - private boolean frozen; - - public synchronized void addListener(Consumer listener) - { - checkState(!frozen, "Listeners have been invoked"); - listeners.add(listener); - } - - public synchronized void invoke(T payload, Executor executor) - { - frozen = true; - for (Consumer listener : listeners) { - executor.execute(() -> listener.accept(payload)); - } - } - } -} diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java index 72ecb2420103..afdf36f9d3e4 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java @@ -91,6 +91,7 @@ public class SqlTask private final AtomicReference taskHolderReference = new AtomicReference<>(new TaskHolder()); private final AtomicBoolean needsPlan = new AtomicBoolean(true); + private final AtomicReference traceToken = new AtomicReference<>(); public static SqlTask createSqlTask( TaskId taskId, @@ -425,6 +426,9 @@ public TaskInfo updateTask( Map dynamicFilterDomains) { try { + // trace token must be set first to make sure failure injection for getTaskResults requests works as expected + session.getTraceToken().ifPresent(traceToken::set); + // The LazyOutput buffer does not support write methods, so the actual // output buffer must be established before drivers are created (e.g. // a VALUES query). 
@@ -494,11 +498,12 @@ public TaskInfo abortTaskResults(OutputBufferId bufferId) return getTaskInfo(); } - public void failed(Throwable cause) + public TaskInfo failed(Throwable cause) { requireNonNull(cause, "cause is null"); taskStateMachine.failed(cause); + return getTaskInfo(); } public TaskInfo cancel() @@ -589,6 +594,11 @@ public void addStateChangeListener(StateChangeListener stateChangeLis taskStateMachine.addStateChangeListener(stateChangeListener); } + public void addSourceTaskFailureListener(TaskFailureListener listener) + { + taskStateMachine.addSourceTaskFailureListener(listener); + } + public QueryContext getQueryContext() { return queryContext; @@ -602,4 +612,9 @@ public Optional getTaskContext() } return Optional.of(taskExecution.getTaskContext()); } + + public Optional getTraceToken() + { + return Optional.ofNullable(traceToken.get()); + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java index bdc29bb19bda..cfc69152e129 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java @@ -463,6 +463,15 @@ public TaskInfo abortTask(TaskId taskId) return tasks.getUnchecked(taskId).abort(); } + @Override + public TaskInfo failTask(TaskId taskId, Throwable failure) + { + requireNonNull(taskId, "taskId is null"); + requireNonNull(failure, "failure is null"); + + return tasks.getUnchecked(taskId).failed(failure); + } + public void removeOldTasks() { DateTime oldestAllowedTask = DateTime.now().minus(infoCacheTime.toMillis()); @@ -533,6 +542,18 @@ public void addStateChangeListener(TaskId taskId, StateChangeListener tasks.getUnchecked(taskId).addStateChangeListener(stateChangeListener); } + @Override + public void addSourceTaskFailureListener(TaskId taskId, TaskFailureListener listener) + { + tasks.getUnchecked(taskId).addSourceTaskFailureListener(listener); + } + 
+ @Override + public Optional getTraceToken(TaskId taskId) + { + return tasks.getUnchecked(taskId).getTraceToken(); + } + @VisibleForTesting public QueryContext getQueryContext(QueryId queryId) diff --git a/core/trino-main/src/main/java/io/trino/execution/StageInfo.java b/core/trino-main/src/main/java/io/trino/execution/StageInfo.java index d6f24b089cb5..ffd7172ab6bf 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageInfo.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageInfo.java @@ -37,6 +37,7 @@ public class StageInfo private final StageId stageId; private final StageState state; private final PlanFragment plan; + private final boolean coordinatorOnly; private final List types; private final StageStats stageStats; private final List tasks; @@ -49,6 +50,7 @@ public StageInfo( @JsonProperty("stageId") StageId stageId, @JsonProperty("state") StageState state, @JsonProperty("plan") @Nullable PlanFragment plan, + @JsonProperty("coordinatorOnly") boolean coordinatorOnly, @JsonProperty("types") List types, @JsonProperty("stageStats") StageStats stageStats, @JsonProperty("tasks") List tasks, @@ -66,6 +68,7 @@ public StageInfo( this.stageId = stageId; this.state = state; this.plan = plan; + this.coordinatorOnly = coordinatorOnly; this.types = types; this.stageStats = stageStats; this.tasks = ImmutableList.copyOf(tasks); @@ -93,6 +96,12 @@ public PlanFragment getPlan() return plan; } + @JsonProperty + public boolean isCoordinatorOnly() + { + return coordinatorOnly; + } + @JsonProperty public List getTypes() { diff --git a/core/trino-main/src/main/java/io/trino/execution/StageState.java b/core/trino-main/src/main/java/io/trino/execution/StageState.java index 27013ba787a0..c543dd335cf7 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageState.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageState.java @@ -31,31 +31,18 @@ public enum StageState * Stage tasks are being scheduled on nodes. 
*/ SCHEDULING(false, false), - /** - * All stage tasks have been scheduled, but splits are still being scheduled. - */ - SCHEDULING_SPLITS(false, false), - /** - * Stage has been scheduled on nodes and ready to execute, but all tasks are still queued. - */ - SCHEDULED(false, false), /** * Stage is running. */ RUNNING(false, false), /** - * Stage has finished executing and output being consumed. - * In this state, at-least one of the tasks is flushing and the non-flushing tasks are finished + * Stage has finished executing existing tasks but more tasks could be scheduled in the future. */ - FLUSHING(false, false), + PENDING(false, false), /** * Stage has finished executing and all output has been consumed. */ FINISHED(true, false), - /** - * Stage was canceled by a user. - */ - CANCELED(true, false), /** * Stage was aborted due to a failure in the query. The failure * was not in this stage. @@ -93,29 +80,4 @@ public boolean isFailure() { return failureState; } - - public boolean canScheduleMoreTasks() - { - switch (this) { - case PLANNED: - case SCHEDULING: - // workers are still being added to the query - return true; - case SCHEDULING_SPLITS: - case SCHEDULED: - case RUNNING: - case FLUSHING: - case FINISHED: - case CANCELED: - // no more workers will be added to the query - return false; - case ABORTED: - case FAILED: - // DO NOT complete a FAILED or ABORTED stage. This will cause the - // stage above to finish normally, which will result in a query - // completing successfully when it should fail.. 
- return true; - } - throw new IllegalStateException("Unhandled state: " + this); - } } diff --git a/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java index 4acdac5184f1..ff3885dc369c 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java @@ -18,7 +18,6 @@ import io.airlift.log.Logger; import io.airlift.stats.Distribution; import io.airlift.units.Duration; -import io.trino.Session; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.operator.BlockedReason; @@ -53,15 +52,12 @@ import static io.airlift.units.DataSize.succinctBytes; import static io.airlift.units.Duration.succinctDuration; import static io.trino.execution.StageState.ABORTED; -import static io.trino.execution.StageState.CANCELED; import static io.trino.execution.StageState.FAILED; import static io.trino.execution.StageState.FINISHED; -import static io.trino.execution.StageState.FLUSHING; +import static io.trino.execution.StageState.PENDING; import static io.trino.execution.StageState.PLANNED; import static io.trino.execution.StageState.RUNNING; -import static io.trino.execution.StageState.SCHEDULED; import static io.trino.execution.StageState.SCHEDULING; -import static io.trino.execution.StageState.SCHEDULING_SPLITS; import static io.trino.execution.StageState.TERMINAL_STAGE_STATES; import static java.lang.Math.max; import static java.lang.Math.min; @@ -77,7 +73,6 @@ public class StageStateMachine private final StageId stageId; private final PlanFragment fragment; - private final Session session; private final Map tables; private final SplitSchedulerStats scheduledStats; @@ -96,14 +91,12 @@ public class StageStateMachine public StageStateMachine( StageId stageId, - Session session, PlanFragment fragment, Map tables, 
Executor executor, SplitSchedulerStats schedulerStats) { this.stageId = requireNonNull(stageId, "stageId is null"); - this.session = requireNonNull(session, "session is null"); this.fragment = requireNonNull(fragment, "fragment is null"); this.tables = ImmutableMap.copyOf(requireNonNull(tables, "tables is null")); this.scheduledStats = requireNonNull(schedulerStats, "schedulerStats is null"); @@ -119,11 +112,6 @@ public StageId getStageId() return stageId; } - public Session getSession() - { - return session; - } - public StageState getState() { return stageState.get(); @@ -144,30 +132,20 @@ public void addStateChangeListener(StateChangeListener stateChangeLi stageState.addStateChangeListener(stateChangeListener); } - public synchronized boolean transitionToScheduling() + public boolean transitionToScheduling() { return stageState.compareAndSet(PLANNED, SCHEDULING); } - public synchronized boolean transitionToSchedulingSplits() - { - return stageState.setIf(SCHEDULING_SPLITS, currentState -> currentState == PLANNED || currentState == SCHEDULING); - } - - public synchronized boolean transitionToScheduled() - { - schedulingComplete.compareAndSet(null, DateTime.now()); - return stageState.setIf(SCHEDULED, currentState -> currentState == PLANNED || currentState == SCHEDULING || currentState == SCHEDULING_SPLITS); - } - public boolean transitionToRunning() { - return stageState.setIf(RUNNING, currentState -> currentState != RUNNING && currentState != FLUSHING && !currentState.isDone()); + schedulingComplete.compareAndSet(null, DateTime.now()); + return stageState.setIf(RUNNING, currentState -> currentState != RUNNING && !currentState.isDone()); } - public boolean transitionToFlushing() + public boolean transitionToPending() { - return stageState.setIf(FLUSHING, currentState -> currentState != FLUSHING && !currentState.isDone()); + return stageState.setIf(PENDING, currentState -> currentState != PENDING && !currentState.isDone()); } public boolean transitionToFinished() 
@@ -175,11 +153,6 @@ public boolean transitionToFinished() return stageState.setIf(FINISHED, currentState -> !currentState.isDone()); } - public boolean transitionToCanceled() - { - return stageState.setIf(CANCELED, currentState -> !currentState.isDone()); - } - public boolean transitionToAborted() { return stageState.setIf(ABORTED, currentState -> !currentState.isDone()); @@ -259,7 +232,7 @@ public BasicStageStats getBasicStageStats(Supplier> taskInfos // information, the stage could finish, and the task states would // never be visible. StageState state = stageState.get(); - boolean isScheduled = state == RUNNING || state == FLUSHING || state.isDone(); + boolean isScheduled = state == RUNNING || state == StageState.PENDING || state.isDone(); List taskInfos = ImmutableList.copyOf(taskInfosSupplier.get()); @@ -382,6 +355,7 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier) int totalTasks = taskInfos.size(); int runningTasks = 0; int completedTasks = 0; + int failedTasks = 0; int totalDrivers = 0; int queuedDrivers = 0; @@ -439,6 +413,10 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier) runningTasks++; } + if (taskState == TaskState.FAILED) { + failedTasks++; + } + TaskStats taskStats = taskInfo.getStats(); totalDrivers += taskStats.getTotalDrivers(); @@ -507,6 +485,7 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier) totalTasks, runningTasks, completedTasks, + failedTasks, totalDrivers, queuedDrivers, @@ -559,9 +538,11 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier) if (state == FAILED) { failureInfo = failureCause.get(); } - return new StageInfo(stageId, + return new StageInfo( + stageId, state, fragment, + fragment.getPartitioning().isCoordinatorOnly(), fragment.getTypes(), stageStats, taskInfos, diff --git a/core/trino-main/src/main/java/io/trino/execution/StageStats.java b/core/trino-main/src/main/java/io/trino/execution/StageStats.java index 50d3613ac916..d7d1c6645913 100644 --- 
a/core/trino-main/src/main/java/io/trino/execution/StageStats.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageStats.java @@ -32,7 +32,6 @@ import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; -import static io.trino.execution.StageState.FLUSHING; import static io.trino.execution.StageState.RUNNING; import static java.lang.Math.min; import static java.util.Objects.requireNonNull; @@ -47,6 +46,7 @@ public class StageStats private final int totalTasks; private final int runningTasks; private final int completedTasks; + private final int failedTasks; private final int totalDrivers; private final int queuedDrivers; @@ -100,6 +100,7 @@ public StageStats( @JsonProperty("totalTasks") int totalTasks, @JsonProperty("runningTasks") int runningTasks, @JsonProperty("completedTasks") int completedTasks, + @JsonProperty("failedTasks") int failedTasks, @JsonProperty("totalDrivers") int totalDrivers, @JsonProperty("queuedDrivers") int queuedDrivers, @@ -153,6 +154,8 @@ public StageStats( this.runningTasks = runningTasks; checkArgument(completedTasks >= 0, "completedTasks is negative"); this.completedTasks = completedTasks; + checkArgument(failedTasks >= 0, "failedTasks is negative"); + this.failedTasks = failedTasks; checkArgument(totalDrivers >= 0, "totalDrivers is negative"); this.totalDrivers = totalDrivers; @@ -239,6 +242,12 @@ public int getCompletedTasks() return completedTasks; } + @JsonProperty + public int getFailedTasks() + { + return failedTasks; + } + @JsonProperty public int getTotalDrivers() { @@ -433,7 +442,7 @@ public List getOperatorSummaries() public BasicStageStats toBasicStageStats(StageState stageState) { - boolean isScheduled = stageState == RUNNING || stageState == FLUSHING || stageState.isDone(); + boolean isScheduled = stageState == RUNNING || stageState == StageState.PENDING || stageState.isDone(); OptionalDouble progressPercentage = OptionalDouble.empty(); if (isScheduled && totalDrivers != 0) { diff 
--git a/core/trino-main/src/main/java/io/trino/execution/TaskFailureListener.java b/core/trino-main/src/main/java/io/trino/execution/TaskFailureListener.java new file mode 100644 index 000000000000..d051911e0660 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/TaskFailureListener.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution; + +public interface TaskFailureListener +{ + void onTaskFailed(TaskId taskId, Throwable failure); +} diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskId.java b/core/trino-main/src/main/java/io/trino/execution/TaskId.java index 9409433232e9..b0041f99953d 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskId.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskId.java @@ -21,7 +21,9 @@ import java.util.Objects; import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.spi.QueryId.parseDottedId; import static java.lang.Integer.parseInt; +import static java.util.Objects.requireNonNull; public class TaskId { @@ -33,37 +35,38 @@ public static TaskId valueOf(String taskId) private final String fullId; - public TaskId(String queryId, int stageId, int id) + public TaskId(StageId stageId, int partitionId, int attemptId) { - checkArgument(id >= 0, "id is negative"); - this.fullId = queryId + "." + stageId + "." 
+ id; + requireNonNull(stageId, "stageId is null"); + checkArgument(partitionId >= 0, "partitionId is negative"); + checkArgument(attemptId >= 0, "attemptId is negative"); + this.fullId = stageId + "." + partitionId + "." + attemptId; } - public TaskId(StageId stageId, int id) + private TaskId(String fullId) { - checkArgument(id >= 0, "id is negative"); - this.fullId = stageId.getQueryId().getId() + "." + stageId.getId() + "." + id; - } - - public TaskId(String fullId) - { - this.fullId = fullId; + this.fullId = requireNonNull(fullId, "fullId is null"); } public QueryId getQueryId() { - return new QueryId(QueryId.parseDottedId(fullId, 3, "taskId").get(0)); + return new QueryId(parseDottedId(fullId, 4, "taskId").get(0)); } public StageId getStageId() { - List ids = QueryId.parseDottedId(fullId, 3, "taskId"); + List ids = parseDottedId(fullId, 4, "taskId"); return StageId.valueOf(ids.subList(0, 2)); } - public int getId() + public int getPartitionId() + { + return parseInt(parseDottedId(fullId, 4, "taskId").get(2)); + } + + public int getAttemptId() { - return parseInt(QueryId.parseDottedId(fullId, 3, "taskId").get(2)); + return parseInt(parseDottedId(fullId, 4, "taskId").get(3)); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskManager.java b/core/trino-main/src/main/java/io/trino/execution/TaskManager.java index 37d00984d213..8ae16c746663 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskManager.java @@ -108,6 +108,12 @@ TaskInfo updateTask( */ TaskInfo abortTask(TaskId taskId); + /** + * Fail a task. If the task does not already exist, it is created and then + * failed. + */ + TaskInfo failTask(TaskId taskId, Throwable failure); + /** * Gets results from a task either immediately or in the future. 
If the * task or buffer has not been created yet, an uninitialized task is @@ -140,4 +146,14 @@ TaskInfo updateTask( * possible notifications are observed out of order due to the asynchronous execution. */ void addStateChangeListener(TaskId taskId, StateChangeListener stateChangeListener); + + /** + * Add a listener that notifies about failures of any source tasks for a given task + */ + void addSourceTaskFailureListener(TaskId taskId, TaskFailureListener listener); + + /** + * Return trace token for a given task (see Session#traceToken) + */ + Optional getTraceToken(TaskId taskId); } diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/TaskStateMachine.java index 3a0617708b5c..1e9b6e890dca 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskStateMachine.java @@ -13,13 +13,20 @@ */ package io.trino.execution; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.log.Logger; import io.trino.execution.StateMachine.StateChangeListener; import org.joda.time.DateTime; +import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.concurrent.Executor; import java.util.concurrent.LinkedBlockingQueue; @@ -39,12 +46,19 @@ public class TaskStateMachine private final DateTime createdTime = DateTime.now(); private final TaskId taskId; + private final Executor executor; private final StateMachine taskState; private final LinkedBlockingQueue failureCauses = new LinkedBlockingQueue<>(); + @GuardedBy("this") + private final Map sourceTaskFailures = new HashMap<>(); + @GuardedBy("this") + private final List sourceTaskFailureListeners = new 
ArrayList<>(); + public TaskStateMachine(TaskId taskId, Executor executor) { this.taskId = requireNonNull(taskId, "taskId is null"); + this.executor = requireNonNull(executor, "executor is null"); taskState = new StateMachine<>("task " + taskId, executor, TaskState.RUNNING, TERMINAL_TASK_STATES); taskState.addStateChangeListener(newState -> log.debug("Task %s is %s", taskId, newState)); } @@ -126,6 +140,32 @@ public void addStateChangeListener(StateChangeListener stateChangeLis taskState.addStateChangeListener(stateChangeListener); } + public void addSourceTaskFailureListener(TaskFailureListener listener) + { + Map failures; + synchronized (this) { + sourceTaskFailureListeners.add(listener); + failures = ImmutableMap.copyOf(sourceTaskFailures); + } + executor.execute(() -> { + failures.forEach(listener::onTaskFailed); + }); + } + + public void sourceTaskFailed(TaskId taskId, Throwable failure) + { + List listeners; + synchronized (this) { + sourceTaskFailures.putIfAbsent(taskId, failure); + listeners = ImmutableList.copyOf(sourceTaskFailureListeners); + } + executor.execute(() -> { + for (TaskFailureListener listener : listeners) { + listener.onTaskFailed(taskId, failure); + } + }); + } + @Override public String toString() { diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionPolicy.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionPolicy.java index 34e7ed2bdb9c..c12272190cfe 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionPolicy.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionPolicy.java @@ -13,15 +13,13 @@ */ package io.trino.execution.scheduler; -import io.trino.execution.SqlStageExecution; - import java.util.Collection; public class AllAtOnceExecutionPolicy implements ExecutionPolicy { @Override - public ExecutionSchedule createExecutionSchedule(Collection stages) + public ExecutionSchedule 
createExecutionSchedule(Collection stages) { return new AllAtOnceExecutionSchedule(stages); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionSchedule.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionSchedule.java index 8d46987a8bff..ff67f1938c31 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionSchedule.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionSchedule.java @@ -17,8 +17,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Ordering; -import io.trino.execution.SqlStageExecution; -import io.trino.execution.StageState; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.IndexJoinNode; @@ -42,36 +40,35 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.collect.Iterables.getOnlyElement; -import static io.trino.execution.StageState.FLUSHING; -import static io.trino.execution.StageState.RUNNING; -import static io.trino.execution.StageState.SCHEDULED; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.FLUSHING; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.RUNNING; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.SCHEDULED; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; public class AllAtOnceExecutionSchedule implements ExecutionSchedule { - private final Set schedulingStages; + private final Set schedulingStages; - public AllAtOnceExecutionSchedule(Collection stages) + public AllAtOnceExecutionSchedule(Collection stages) { requireNonNull(stages, 
"stages is null"); List preferredScheduleOrder = getPreferredScheduleOrder(stages.stream() - .map(SqlStageExecution::getFragment) + .map(PipelinedStageExecution::getFragment) .collect(toImmutableList())); - Ordering ordering = Ordering.explicit(preferredScheduleOrder) + Ordering ordering = Ordering.explicit(preferredScheduleOrder) .onResultOf(PlanFragment::getId) - .onResultOf(SqlStageExecution::getFragment); + .onResultOf(PipelinedStageExecution::getFragment); schedulingStages = new LinkedHashSet<>(ordering.sortedCopy(stages)); } @Override - public Set getStagesToSchedule() + public Set getStagesToSchedule() { - for (Iterator iterator = schedulingStages.iterator(); iterator.hasNext(); ) { - StageState state = iterator.next().getState(); + for (Iterator iterator = schedulingStages.iterator(); iterator.hasNext(); ) { + PipelinedStageExecution.State state = iterator.next().getState(); if (state == SCHEDULED || state == RUNNING || state == FLUSHING || state.isDone()) { iterator.remove(); } @@ -99,10 +96,9 @@ static List getPreferredScheduleOrder(Collection f Set rootFragments = fragments.stream() .filter(fragment -> !remoteSources.contains(fragment.getId())) .collect(toImmutableSet()); - checkArgument(rootFragments.size() == 1, "Expected one root fragment, but found: %s", rootFragments); Visitor visitor = new Visitor(fragments); - visitor.processFragment(getOnlyElement(rootFragments).getId()); + rootFragments.forEach(fragment -> visitor.processFragment(fragment.getId())); return visitor.getSchedulerOrder(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/BroadcastOutputBufferManager.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/BroadcastOutputBufferManager.java index af4da83997e9..d66cce34772d 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/BroadcastOutputBufferManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/BroadcastOutputBufferManager.java @@ -19,57 +19,46 @@ import 
javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; -import java.util.List; -import java.util.function.Consumer; - import static io.trino.execution.buffer.OutputBuffers.BROADCAST_PARTITION_ID; import static io.trino.execution.buffer.OutputBuffers.BufferType.BROADCAST; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; -import static java.util.Objects.requireNonNull; @ThreadSafe class BroadcastOutputBufferManager implements OutputBufferManager { - private final Consumer outputBufferTarget; - @GuardedBy("this") private OutputBuffers outputBuffers = createInitialEmptyOutputBuffers(BROADCAST); - public BroadcastOutputBufferManager(Consumer outputBufferTarget) - { - this.outputBufferTarget = requireNonNull(outputBufferTarget, "outputBufferTarget is null"); - outputBufferTarget.accept(outputBuffers); - } - @Override - public void addOutputBuffers(List newBuffers, boolean noMoreBuffers) + public synchronized void addOutputBuffer(OutputBufferId newBuffer) { - OutputBuffers newOutputBuffers; - synchronized (this) { - if (outputBuffers.isNoMoreBufferIds()) { - // a stage can move to a final state (e.g., failed) while scheduling, so ignore - // the new buffers - return; - } - - OutputBuffers originalOutputBuffers = outputBuffers; + if (outputBuffers.isNoMoreBufferIds()) { + // a stage can move to a final state (e.g., failed) while scheduling, so ignore + // the new buffers + return; + } - // Note: it does not matter which partition id the task is using, in broadcast all tasks read from the same partition - for (OutputBufferId newBuffer : newBuffers) { - outputBuffers = outputBuffers.withBuffer(newBuffer, BROADCAST_PARTITION_ID); - } + // Note: it does not matter which partition id the task is using, in broadcast all tasks read from the same partition + OutputBuffers newOutputBuffers = outputBuffers.withBuffer(newBuffer, BROADCAST_PARTITION_ID); - if (noMoreBuffers) { - outputBuffers = 
outputBuffers.withNoMoreBufferIds(); - } + // don't update if nothing changed + if (newOutputBuffers != outputBuffers) { + this.outputBuffers = newOutputBuffers; + } + } - // don't update if nothing changed - if (outputBuffers == originalOutputBuffers) { - return; - } - newOutputBuffers = this.outputBuffers; + @Override + public synchronized void noMoreBuffers() + { + if (!outputBuffers.isNoMoreBufferIds()) { + outputBuffers = outputBuffers.withNoMoreBufferIds(); } - outputBufferTarget.accept(newOutputBuffers); + } + + @Override + public synchronized OutputBuffers getOutputBuffers() + { + return outputBuffers; } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionPolicy.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionPolicy.java index 46c91e7e919e..6c7c2b7bbc4b 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionPolicy.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionPolicy.java @@ -13,11 +13,9 @@ */ package io.trino.execution.scheduler; -import io.trino.execution.SqlStageExecution; - import java.util.Collection; public interface ExecutionPolicy { - ExecutionSchedule createExecutionSchedule(Collection stages); + ExecutionSchedule createExecutionSchedule(Collection stages); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionSchedule.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionSchedule.java index 1b5096d2f5dc..221f975570a0 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionSchedule.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionSchedule.java @@ -13,13 +13,11 @@ */ package io.trino.execution.scheduler; -import io.trino.execution.SqlStageExecution; - import java.util.Set; public interface ExecutionSchedule { - Set getStagesToSchedule(); + Set getStagesToSchedule(); boolean isFinished(); } diff --git 
a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountScheduler.java index 6f8bf8e4ed34..04998fca219e 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountScheduler.java @@ -14,8 +14,8 @@ package io.trino.execution.scheduler; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMultimap; import io.trino.execution.RemoteTask; -import io.trino.execution.SqlStageExecution; import io.trino.metadata.InternalNode; import java.util.List; @@ -36,10 +36,10 @@ public interface TaskScheduler private final TaskScheduler taskScheduler; private final List partitionToNode; - public FixedCountScheduler(SqlStageExecution stage, List partitionToNode) + public FixedCountScheduler(PipelinedStageExecution stageExecution, List partitionToNode) { - requireNonNull(stage, "stage is null"); - this.taskScheduler = stage::scheduleTask; + requireNonNull(stageExecution, "stageExecution is null"); + this.taskScheduler = (node, partition) -> stageExecution.scheduleTask(node, partition, ImmutableMultimap.of(), ImmutableMultimap.of()); this.partitionToNode = requireNonNull(partitionToNode, "partitionToNode is null"); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java index 6cc76e0d83a6..f5cb581bd1d5 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java @@ -14,13 +14,12 @@ package io.trino.execution.scheduler; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMultimap; import
com.google.common.collect.ImmutableSet; -import com.google.common.collect.Streams; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.log.Logger; import io.trino.execution.Lifespan; import io.trino.execution.RemoteTask; -import io.trino.execution.SqlStageExecution; import io.trino.execution.TableExecuteContextManager; import io.trino.execution.scheduler.ScheduleResult.BlockedReason; import io.trino.execution.scheduler.group.DynamicLifespanScheduler; @@ -35,6 +34,7 @@ import io.trino.sql.planner.plan.PlanNodeId; import java.util.ArrayList; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -46,11 +46,9 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.concurrent.MoreFutures.whenAnyComplete; import static io.trino.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsSourceScheduler; import static io.trino.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; -import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class FixedSourcePartitionedScheduler @@ -58,15 +56,17 @@ public class FixedSourcePartitionedScheduler { private static final Logger log = Logger.get(FixedSourcePartitionedScheduler.class); - private final SqlStageExecution stage; + private final PipelinedStageExecution stageExecution; private final List nodes; private final List sourceSchedulers; private final List partitionHandles; - private boolean scheduledTasks; private final Optional groupedLifespanScheduler; + private final PartitionIdAllocator partitionIdAllocator; + private final Map scheduledTasks; + public FixedSourcePartitionedScheduler( - SqlStageExecution stage, + PipelinedStageExecution stageExecution, Map splitSources, 
StageExecutionDescriptor stageExecutionDescriptor, List schedulingOrder, @@ -79,20 +79,20 @@ public FixedSourcePartitionedScheduler( DynamicFilterService dynamicFilterService, TableExecuteContextManager tableExecuteContextManager) { - requireNonNull(stage, "stage is null"); + requireNonNull(stageExecution, "stageExecution is null"); requireNonNull(splitSources, "splitSources is null"); requireNonNull(bucketNodeMap, "bucketNodeMap is null"); checkArgument(!requireNonNull(nodes, "nodes is null").isEmpty(), "nodes is empty"); requireNonNull(partitionHandles, "partitionHandles is null"); requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); - this.stage = stage; + this.stageExecution = stageExecution; this.nodes = ImmutableList.copyOf(nodes); this.partitionHandles = ImmutableList.copyOf(partitionHandles); checkArgument(splitSources.keySet().equals(ImmutableSet.copyOf(schedulingOrder))); - BucketedSplitPlacementPolicy splitPlacementPolicy = new BucketedSplitPlacementPolicy(nodeSelector, nodes, bucketNodeMap, stage::getAllTasks); + BucketedSplitPlacementPolicy splitPlacementPolicy = new BucketedSplitPlacementPolicy(nodeSelector, nodes, bucketNodeMap, stageExecution::getAllTasks); ArrayList sourceSchedulers = new ArrayList<>(); checkArgument( @@ -109,13 +109,16 @@ public FixedSourcePartitionedScheduler( boolean firstPlanNode = true; Optional groupedLifespanScheduler = Optional.empty(); + + partitionIdAllocator = new PartitionIdAllocator(); + scheduledTasks = new HashMap<>(); for (PlanNodeId planNodeId : schedulingOrder) { SplitSource splitSource = splitSources.get(planNodeId); boolean groupedExecutionForScanNode = stageExecutionDescriptor.isScanGroupedExecution(planNodeId); // TODO : change anySourceTaskBlocked to accommodate the correct blocked status of source tasks // (ref : https://github.com/trinodb/trino/issues/4713) SourceScheduler sourceScheduler = newSourcePartitionedSchedulerAsSourceScheduler( - stage, + stageExecution, planNodeId, 
splitSource, splitPlacementPolicy, @@ -123,7 +126,9 @@ public FixedSourcePartitionedScheduler( groupedExecutionForScanNode, dynamicFilterService, tableExecuteContextManager, - () -> true); + () -> true, + partitionIdAllocator, + scheduledTasks); if (stageExecutionDescriptor.isStageGroupedExecution() && !groupedExecutionForScanNode) { sourceScheduler = new AsGroupedSourceScheduler(sourceScheduler); @@ -153,7 +158,7 @@ public FixedSourcePartitionedScheduler( // Schedule the first few lifespans lifespanScheduler.scheduleInitial(sourceScheduler); // Schedule new lifespans for finished ones - stage.addCompletedDriverGroupsChangedListener(lifespanScheduler::onLifespanFinished); + stageExecution.addCompletedDriverGroupsChangedListener(lifespanScheduler::onLifespanFinished); groupedLifespanScheduler = Optional.of(lifespanScheduler); } } @@ -175,14 +180,16 @@ public ScheduleResult schedule() { // schedule a task on every node in the distribution List newTasks = ImmutableList.of(); - if (!scheduledTasks) { - newTasks = Streams.mapWithIndex( - nodes.stream(), - (node, id) -> stage.scheduleTask(node, toIntExact(id))) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(toImmutableList()); - scheduledTasks = true; + if (scheduledTasks.isEmpty()) { + ImmutableList.Builder newTasksBuilder = ImmutableList.builder(); + for (InternalNode node : nodes) { + Optional task = stageExecution.scheduleTask(node, partitionIdAllocator.getNextId(), ImmutableMultimap.of(), ImmutableMultimap.of()); + if (task.isPresent()) { + scheduledTasks.put(node, task.get()); + newTasksBuilder.add(task.get()); + } + } + newTasks = newTasksBuilder.build(); } boolean allBlocked = true; @@ -226,7 +233,7 @@ public ScheduleResult schedule() driverGroupsToStart = sourceScheduler.drainCompletedLifespans(); if (schedule.isFinished()) { - stage.schedulingComplete(sourceScheduler.getPlanNodeId()); + stageExecution.schedulingComplete(sourceScheduler.getPlanNodeId()); schedulerIterator.remove(); 
sourceScheduler.close(); shouldInvokeNoMoreDriverGroups = true; diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/OutputBufferManager.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/OutputBufferManager.java index c655e2487426..d3a98bcf3a0f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/OutputBufferManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/OutputBufferManager.java @@ -13,11 +13,14 @@ */ package io.trino.execution.scheduler; +import io.trino.execution.buffer.OutputBuffers; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; -import java.util.List; - interface OutputBufferManager { - void addOutputBuffers(List newBuffers, boolean noMoreBuffers); + void addOutputBuffer(OutputBufferId newBuffer); + + void noMoreBuffers(); + + OutputBuffers getOutputBuffers(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionIdAllocator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionIdAllocator.java new file mode 100644 index 000000000000..e16123d0ced7 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionIdAllocator.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import java.util.concurrent.atomic.AtomicInteger; + +public class PartitionIdAllocator +{ + private final AtomicInteger nextId = new AtomicInteger(); + + public int getNextId() + { + return nextId.getAndIncrement(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionedOutputBufferManager.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionedOutputBufferManager.java index bad4a9e1e162..21a4df964fb8 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionedOutputBufferManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionedOutputBufferManager.java @@ -20,10 +20,6 @@ import javax.annotation.concurrent.ThreadSafe; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; - import static com.google.common.base.Preconditions.checkArgument; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static java.util.Objects.requireNonNull; @@ -32,9 +28,9 @@ public class PartitionedOutputBufferManager implements OutputBufferManager { - private final Map outputBuffers; + private final OutputBuffers outputBuffers; - public PartitionedOutputBufferManager(PartitioningHandle partitioningHandle, int partitionCount, Consumer outputBufferTarget) + public PartitionedOutputBufferManager(PartitioningHandle partitioningHandle, int partitionCount) { checkArgument(partitionCount >= 1, "partitionCount must be at least 1"); @@ -43,27 +39,31 @@ public PartitionedOutputBufferManager(PartitioningHandle partitioningHandle, int partitions.put(new OutputBufferId(partition), partition); } - OutputBuffers outputBuffers = createInitialEmptyOutputBuffers(requireNonNull(partitioningHandle, "partitioningHandle is null")) + outputBuffers = createInitialEmptyOutputBuffers(requireNonNull(partitioningHandle, "partitioningHandle is null")) .withBuffers(partitions.build()) 
.withNoMoreBufferIds(); - outputBufferTarget.accept(outputBuffers); - - this.outputBuffers = outputBuffers.getBuffers(); } @Override - public void addOutputBuffers(List newBuffers, boolean noMoreBuffers) + public void addOutputBuffer(OutputBufferId newBuffer) { // All buffers are created in the constructor, so just validate that this isn't // a request to add a new buffer - for (OutputBufferId newBuffer : newBuffers) { - Integer existingBufferId = outputBuffers.get(newBuffer); - if (existingBufferId == null) { - throw new IllegalStateException("Unexpected new output buffer " + newBuffer); - } - if (newBuffer.getId() != existingBufferId) { - throw new IllegalStateException("newOutputBuffers has changed the assignment for task " + newBuffer); - } + Integer existingBufferId = outputBuffers.getBuffers().get(newBuffer); + if (existingBufferId == null) { + throw new IllegalStateException("Unexpected new output buffer " + newBuffer); + } + if (newBuffer.getId() != existingBufferId) { + throw new IllegalStateException("newOutputBuffers has changed the assignment for task " + newBuffer); } } + + @Override + public void noMoreBuffers() {} + + @Override + public OutputBuffers getOutputBuffers() + { + return outputBuffers; + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionPolicy.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionPolicy.java index 99190392603d..626f7e34d4f2 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionPolicy.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionPolicy.java @@ -13,15 +13,13 @@ */ package io.trino.execution.scheduler; -import io.trino.execution.SqlStageExecution; - import java.util.Collection; public class PhasedExecutionPolicy implements ExecutionPolicy { @Override - public ExecutionSchedule createExecutionSchedule(Collection stages) + public ExecutionSchedule createExecutionSchedule(Collection stages) 
{ return new PhasedExecutionSchedule(stages); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionSchedule.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionSchedule.java index f1b14d289021..fa0025305241 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionSchedule.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionSchedule.java @@ -16,8 +16,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.trino.execution.SqlStageExecution; -import io.trino.execution.StageState; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.IndexJoinNode; @@ -50,9 +48,9 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static io.trino.execution.StageState.FLUSHING; -import static io.trino.execution.StageState.RUNNING; -import static io.trino.execution.StageState.SCHEDULED; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.FLUSHING; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.RUNNING; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.SCHEDULED; import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static java.util.function.Function.identity; @@ -60,14 +58,14 @@ public class PhasedExecutionSchedule implements ExecutionSchedule { - private final List> schedulePhases; - private final Set activeSources = new HashSet<>(); + private final List> schedulePhases; + private final Set activeSources = new HashSet<>(); - public PhasedExecutionSchedule(Collection stages) + public PhasedExecutionSchedule(Collection stages) { - 
List> phases = extractPhases(stages.stream().map(SqlStageExecution::getFragment).collect(toImmutableList())); + List> phases = extractPhases(stages.stream().map(PipelinedStageExecution::getFragment).collect(toImmutableList())); - Map stagesByFragmentId = stages.stream().collect(toImmutableMap(stage -> stage.getFragment().getId(), identity())); + Map stagesByFragmentId = stages.stream().collect(toImmutableMap(stage -> stage.getFragment().getId(), identity())); // create a mutable list of mutable sets of stages, so we can remove completed stages schedulePhases = new ArrayList<>(); @@ -79,7 +77,7 @@ public PhasedExecutionSchedule(Collection stages) } @Override - public Set getStagesToSchedule() + public Set getStagesToSchedule() { removeCompletedStages(); addPhasesIfNecessary(); @@ -91,8 +89,8 @@ public Set getStagesToSchedule() private void removeCompletedStages() { - for (Iterator stageIterator = activeSources.iterator(); stageIterator.hasNext(); ) { - StageState state = stageIterator.next().getState(); + for (Iterator stageIterator = activeSources.iterator(); stageIterator.hasNext(); ) { + PipelinedStageExecution.State state = stageIterator.next().getState(); if (state == SCHEDULED || state == RUNNING || state == FLUSHING || state.isDone()) { stageIterator.remove(); } @@ -107,7 +105,7 @@ private void addPhasesIfNecessary() } while (!schedulePhases.isEmpty()) { - Set phase = schedulePhases.remove(0); + Set phase = schedulePhases.remove(0); activeSources.addAll(phase); if (hasSourceDistributedStage(phase)) { return; @@ -115,7 +113,7 @@ private void addPhasesIfNecessary() } } - private static boolean hasSourceDistributedStage(Set phase) + private static boolean hasSourceDistributedStage(Set phase) { return phase.stream().anyMatch(stage -> !stage.getFragment().getPartitionedSources().isEmpty()); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java 
b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java new file mode 100644 index 000000000000..d9812b0a2233 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java @@ -0,0 +1,735 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multimap; +import com.google.common.collect.Sets; +import io.airlift.log.Logger; +import io.trino.execution.ExecutionFailureInfo; +import io.trino.execution.Lifespan; +import io.trino.execution.RemoteTask; +import io.trino.execution.SqlStage; +import io.trino.execution.StageId; +import io.trino.execution.StateMachine; +import io.trino.execution.StateMachine.StateChangeListener; +import io.trino.execution.TaskId; +import io.trino.execution.TaskState; +import io.trino.execution.TaskStatus; +import io.trino.execution.buffer.OutputBuffers; +import io.trino.execution.buffer.OutputBuffers.OutputBufferId; +import io.trino.failuredetector.FailureDetector; +import io.trino.metadata.InternalNode; +import io.trino.metadata.Split; +import io.trino.spi.TrinoException; +import io.trino.split.RemoteSplit; +import 
io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.RemoteSourceNode; +import io.trino.util.Failures; +import org.joda.time.DateTime; + +import javax.annotation.concurrent.GuardedBy; + +import java.net.URI; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.ABORTED; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.CANCELED; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.FAILED; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.FINISHED; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.FLUSHING; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.PLANNED; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.RUNNING; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.SCHEDULED; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.SCHEDULING; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.SCHEDULING_SPLITS; +import static io.trino.failuredetector.FailureDetector.State.GONE; +import static 
io.trino.operator.ExchangeOperator.REMOTE_CONNECTOR_ID; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.trino.spi.StandardErrorCode.REMOTE_HOST_GONE; +import static java.util.Objects.requireNonNull; + +public class PipelinedStageExecution +{ + private static final Logger log = Logger.get(PipelinedStageExecution.class); + + private final PipelinedStageStateMachine stateMachine; + private final SqlStage stage; + private final Map outputBufferManagers; + private final TaskLifecycleListener taskLifecycleListener; + private final FailureDetector failureDetector; + private final Executor executor; + private final Optional bucketToPartition; + private final Map exchangeSources; + private final int attempt; + + private final Map tasks = new ConcurrentHashMap<>(); + + // current stage task tracking + @GuardedBy("this") + private final Set allTasks = new HashSet<>(); + @GuardedBy("this") + private final Set finishedTasks = new HashSet<>(); + @GuardedBy("this") + private final Set flushingTasks = new HashSet<>(); + + // source task tracking + @GuardedBy("this") + private final Multimap sourceTasks = HashMultimap.create(); + @GuardedBy("this") + private final Set completeSourceFragments = new HashSet<>(); + @GuardedBy("this") + private final Set completeSources = new HashSet<>(); + + // lifespan tracking + private final Set completedDriverGroups = new HashSet<>(); + private final ListenerManager> completedLifespansChangeListeners = new ListenerManager<>(); + + public static PipelinedStageExecution createPipelinedStageExecution( + SqlStage stage, + Map outputBufferManagers, + TaskLifecycleListener taskLifecycleListener, + FailureDetector failureDetector, + Executor executor, + Optional bucketToPartition, + int attempt) + { + PipelinedStageStateMachine stateMachine = new PipelinedStageStateMachine(stage.getStageId(), executor); + ImmutableMap.Builder exchangeSources = ImmutableMap.builder(); + for (RemoteSourceNode remoteSourceNode : 
stage.getFragment().getRemoteSourceNodes()) { + for (PlanFragmentId planFragmentId : remoteSourceNode.getSourceFragmentIds()) { + exchangeSources.put(planFragmentId, remoteSourceNode); + } + } + PipelinedStageExecution execution = new PipelinedStageExecution( + stateMachine, + stage, + outputBufferManagers, + taskLifecycleListener, + failureDetector, + executor, + bucketToPartition, + exchangeSources.build(), + attempt); + execution.initialize(); + return execution; + } + + private PipelinedStageExecution( + PipelinedStageStateMachine stateMachine, + SqlStage stage, + Map outputBufferManagers, + TaskLifecycleListener taskLifecycleListener, + FailureDetector failureDetector, + Executor executor, + Optional bucketToPartition, + Map exchangeSources, + int attempt) + { + this.stateMachine = requireNonNull(stateMachine, "stateMachine is null"); + this.stage = requireNonNull(stage, "stage is null"); + this.outputBufferManagers = ImmutableMap.copyOf(requireNonNull(outputBufferManagers, "outputBufferManagers is null")); + this.taskLifecycleListener = requireNonNull(taskLifecycleListener, "taskLifecycleListener is null"); + this.failureDetector = requireNonNull(failureDetector, "failureDetector is null"); + this.executor = requireNonNull(executor, "executor is null"); + this.bucketToPartition = requireNonNull(bucketToPartition, "bucketToPartition is null"); + this.exchangeSources = ImmutableMap.copyOf(requireNonNull(exchangeSources, "exchangeSources is null")); + this.attempt = attempt; + } + + private void initialize() + { + stateMachine.addStateChangeListener(state -> { + if (!state.canScheduleMoreTasks()) { + taskLifecycleListener.noMoreTasks(stage.getFragment().getId()); + + // update output buffers + for (PlanFragmentId sourceFragment : exchangeSources.keySet()) { + OutputBufferManager outputBufferManager = outputBufferManagers.get(sourceFragment); + outputBufferManager.noMoreBuffers(); + for (RemoteTask sourceTask : sourceTasks.get(stage.getFragment().getId())) { + 
sourceTask.setOutputBuffers(outputBufferManager.getOutputBuffers()); + } + } + } + }); + } + + public State getState() + { + return stateMachine.getState(); + } + + /** + * Listener is always notified asynchronously using a dedicated notification thread pool so, care should + * be taken to avoid leaking {@code this} when adding a listener in a constructor. + */ + public void addStateChangeListener(StateChangeListener stateChangeListener) + { + stateMachine.addStateChangeListener(stateChangeListener); + } + + public void addCompletedDriverGroupsChangedListener(Consumer> newlyCompletedDriverGroupConsumer) + { + completedLifespansChangeListeners.addListener(newlyCompletedDriverGroupConsumer); + } + + public synchronized void beginScheduling() + { + stateMachine.transitionToScheduling(); + } + + public synchronized void transitionToSchedulingSplits() + { + stateMachine.transitionToSchedulingSplits(); + } + + public synchronized void schedulingComplete() + { + if (!stateMachine.transitionToScheduled()) { + return; + } + + if (isFlushing()) { + stateMachine.transitionToFlushing(); + } + if (finishedTasks.containsAll(allTasks)) { + stateMachine.transitionToFinished(); + } + + for (PlanNodeId partitionedSource : stage.getFragment().getPartitionedSources()) { + schedulingComplete(partitionedSource); + } + } + + private synchronized boolean isFlushing() + { + // to transition to flushing, there must be at least one flushing task, and all others must be flushing or finished. 
+ return !flushingTasks.isEmpty() + && allTasks.stream().allMatch(taskId -> finishedTasks.contains(taskId) || flushingTasks.contains(taskId)); + } + + public synchronized void schedulingComplete(PlanNodeId partitionedSource) + { + for (RemoteTask task : getAllTasks()) { + task.noMoreSplits(partitionedSource); + } + completeSources.add(partitionedSource); + } + + public synchronized void cancel() + { + stateMachine.transitionToCanceled(); + getAllTasks().forEach(RemoteTask::cancel); + } + + public synchronized void abort() + { + stateMachine.transitionToAborted(); + getAllTasks().forEach(RemoteTask::abort); + } + + public synchronized void fail(Throwable failureCause) + { + stateMachine.transitionToFailed(failureCause); + tasks.values().forEach(RemoteTask::abort); + } + + public synchronized void failTask(TaskId taskId, Throwable failureCause) + { + RemoteTask task = requireNonNull(tasks.get(taskId.getPartitionId()), () -> "task not found: " + taskId); + task.fail(failureCause); + fail(failureCause); + } + + public synchronized Optional scheduleTask( + InternalNode node, + int partition, + Multimap initialSplits, + Multimap noMoreSplitsForLifespan) + { + if (stateMachine.getState().isDone()) { + return Optional.empty(); + } + + checkArgument(!tasks.containsKey(partition), "A task for partition %s already exists", partition); + + OutputBuffers outputBuffers = outputBufferManagers.get(stage.getFragment().getId()).getOutputBuffers(); + + Optional optionalTask = stage.createTask( + node, + partition, + attempt, + bucketToPartition, + outputBuffers, + initialSplits, + ImmutableMultimap.of(), + ImmutableSet.of()); + + if (optionalTask.isEmpty()) { + return Optional.empty(); + } + + RemoteTask task = optionalTask.get(); + + tasks.put(partition, task); + + ImmutableMultimap.Builder exchangeSplits = ImmutableMultimap.builder(); + sourceTasks.forEach((fragmentId, sourceTask) -> { + TaskStatus status = sourceTask.getTaskStatus(); + if (status.getState() != TaskState.FINISHED) 
{ + PlanNodeId planNodeId = exchangeSources.get(fragmentId).getId(); + exchangeSplits.put(planNodeId, createExchangeSplit(sourceTask, task)); + } + }); + + allTasks.add(task.getTaskId()); + + task.addSplits(exchangeSplits.build()); + noMoreSplitsForLifespan.forEach(task::noMoreSplits); + completeSources.forEach(task::noMoreSplits); + + task.addStateChangeListener(this::updateTaskStatus); + task.addStateChangeListener(this::updateCompletedDriverGroups); + + task.start(); + + taskLifecycleListener.taskCreated(stage.getFragment().getId(), task); + + // update output buffers + OutputBufferId outputBufferId = new OutputBufferId(task.getTaskId().getPartitionId()); + for (PlanFragmentId sourceFragment : exchangeSources.keySet()) { + OutputBufferManager outputBufferManager = outputBufferManagers.get(sourceFragment); + outputBufferManager.addOutputBuffer(outputBufferId); + for (RemoteTask sourceTask : sourceTasks.get(stage.getFragment().getId())) { + sourceTask.setOutputBuffers(outputBufferManager.getOutputBuffers()); + } + } + + return Optional.of(task); + } + + private synchronized void updateTaskStatus(TaskStatus taskStatus) + { + State stageState = stateMachine.getState(); + if (stageState.isDone()) { + return; + } + + TaskState taskState = taskStatus.getState(); + + switch (taskState) { + case FAILED: + RuntimeException failure = taskStatus.getFailures().stream() + .findFirst() + .map(this::rewriteTransportFailure) + .map(ExecutionFailureInfo::toException) + .orElse(new TrinoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason")); + fail(failure); + break; + case CANCELED: + // A task should only be in the canceled state if the STAGE is cancelled + fail(new TrinoException(GENERIC_INTERNAL_ERROR, "A task is in the CANCELED state but stage is " + stageState)); + break; + case ABORTED: + // A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED) + fail(new TrinoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state 
but stage is " + stageState)); + break; + case FLUSHING: + flushingTasks.add(taskStatus.getTaskId()); + break; + case FINISHED: + finishedTasks.add(taskStatus.getTaskId()); + flushingTasks.remove(taskStatus.getTaskId()); + break; + default: + } + + if (stageState == SCHEDULED || stageState == RUNNING || stageState == FLUSHING) { + if (taskState == TaskState.RUNNING) { + stateMachine.transitionToRunning(); + } + if (isFlushing()) { + stateMachine.transitionToFlushing(); + } + if (finishedTasks.containsAll(allTasks)) { + stateMachine.transitionToFinished(); + } + } + } + + private synchronized void updateCompletedDriverGroups(TaskStatus taskStatus) + { + // Sets.difference returns a view. + // Once we add the difference into `completedDriverGroups`, the view will be empty. + // `completedLifespansChangeListeners.invoke` happens asynchronously. + // As a result, calling the listeners before updating `completedDriverGroups` doesn't make a difference. + // That's why a copy must be made here. + Set newlyCompletedDriverGroups = ImmutableSet.copyOf(Sets.difference(taskStatus.getCompletedDriverGroups(), this.completedDriverGroups)); + if (newlyCompletedDriverGroups.isEmpty()) { + return; + } + completedLifespansChangeListeners.invoke(newlyCompletedDriverGroups, executor); + // newlyCompletedDriverGroups is a view. + // Making changes to completedDriverGroups will change newlyCompletedDriverGroups. 
+ completedDriverGroups.addAll(newlyCompletedDriverGroups); + } + + private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo) + { + if (executionFailureInfo.getRemoteHost() == null || failureDetector.getState(executionFailureInfo.getRemoteHost()) != GONE) { + return executionFailureInfo; + } + + return new ExecutionFailureInfo( + executionFailureInfo.getType(), + executionFailureInfo.getMessage(), + executionFailureInfo.getCause(), + executionFailureInfo.getSuppressed(), + executionFailureInfo.getStack(), + executionFailureInfo.getErrorLocation(), + REMOTE_HOST_GONE.toErrorCode(), + executionFailureInfo.getRemoteHost()); + } + + public TaskLifecycleListener getTaskLifecycleListener() + { + return new TaskLifecycleListener() + { + @Override + public void taskCreated(PlanFragmentId fragmentId, RemoteTask task) + { + sourceTaskCreated(fragmentId, task); + } + + @Override + public void noMoreTasks(PlanFragmentId fragmentId) + { + noMoreSourceTasks(fragmentId); + } + }; + } + + private synchronized void sourceTaskCreated(PlanFragmentId fragmentId, RemoteTask sourceTask) + { + requireNonNull(fragmentId, "fragmentId is null"); + + RemoteSourceNode remoteSource = exchangeSources.get(fragmentId); + checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s", fragmentId, exchangeSources.keySet()); + + sourceTasks.put(fragmentId, sourceTask); + + OutputBufferManager outputBufferManager = outputBufferManagers.get(fragmentId); + sourceTask.setOutputBuffers(outputBufferManager.getOutputBuffers()); + + for (RemoteTask destinationTask : getAllTasks()) { + destinationTask.addSplits(ImmutableMultimap.of(remoteSource.getId(), createExchangeSplit(sourceTask, destinationTask))); + } + } + + private synchronized void noMoreSourceTasks(PlanFragmentId fragmentId) + { + RemoteSourceNode remoteSource = exchangeSources.get(fragmentId); + checkArgument(remoteSource != null, "Unknown remote source %s. 
Known sources are %s", fragmentId, exchangeSources.keySet()); + + completeSourceFragments.add(fragmentId); + + // is the source now complete? + if (completeSourceFragments.containsAll(remoteSource.getSourceFragmentIds())) { + completeSources.add(remoteSource.getId()); + for (RemoteTask task : getAllTasks()) { + task.noMoreSplits(remoteSource.getId()); + } + } + } + + public List getAllTasks() + { + return ImmutableList.copyOf(tasks.values()); + } + + public List getTaskStatuses() + { + return getAllTasks().stream() + .map(RemoteTask::getTaskStatus) + .collect(toImmutableList()); + } + + public boolean isAnyTaskBlocked() + { + return getTaskStatuses().stream().anyMatch(TaskStatus::isOutputBufferOverutilized); + } + + public void recordGetSplitTime(long start) + { + stage.recordGetSplitTime(start); + } + + public StageId getStageId() + { + return stage.getStageId(); + } + + public PlanFragment getFragment() + { + return stage.getFragment(); + } + + public Optional getFailureCause() + { + return stateMachine.getFailureCause(); + } + + private static Split createExchangeSplit(RemoteTask sourceTask, RemoteTask destinationTask) + { + // Fetch the results from the buffer assigned to the task based on id + URI exchangeLocation = sourceTask.getTaskStatus().getSelf(); + URI splitLocation = uriBuilderFrom(exchangeLocation).appendPath("results").appendPath(String.valueOf(destinationTask.getTaskId().getPartitionId())).build(); + return new Split(REMOTE_CONNECTOR_ID, new RemoteSplit(sourceTask.getTaskId(), splitLocation), Lifespan.taskWide()); + } + + public enum State + { + /** + * Stage is planned but has not been scheduled yet. A stage will + * be in the planned state until, the dependencies of the stage + * have begun producing output. + */ + PLANNED(false, false), + /** + * Stage tasks are being scheduled on nodes. + */ + SCHEDULING(false, false), + /** + * All stage tasks have been scheduled, but splits are still being scheduled. 
+ */ + SCHEDULING_SPLITS(false, false), + /** + * Stage has been scheduled on nodes and ready to execute, but all tasks are still queued. + */ + SCHEDULED(false, false), + /** + * Stage is running. + */ + RUNNING(false, false), + /** + * Stage has finished executing and output being consumed. + * In this state, at-least one of the tasks is flushing and the non-flushing tasks are finished + */ + FLUSHING(false, false), + /** + * Stage has finished executing and all output has been consumed. + */ + FINISHED(true, false), + /** + * Stage was canceled by a user. + */ + CANCELED(true, false), + /** + * Stage was aborted due to a failure in the query. The failure + * was not in this stage. + */ + ABORTED(true, true), + /** + * Stage execution failed. + */ + FAILED(true, true); + + private final boolean doneState; + private final boolean failureState; + + State(boolean doneState, boolean failureState) + { + checkArgument(!failureState || doneState, "%s is a non-done failure state", name()); + this.doneState = doneState; + this.failureState = failureState; + } + + /** + * Is this a terminal state. + */ + public boolean isDone() + { + return doneState; + } + + /** + * Is this a non-success terminal state. + */ + public boolean isFailure() + { + return failureState; + } + + public boolean canScheduleMoreTasks() + { + switch (this) { + case PLANNED: + case SCHEDULING: + // workers are still being added to the query + return true; + case SCHEDULING_SPLITS: + case SCHEDULED: + case RUNNING: + case FLUSHING: + case FINISHED: + case CANCELED: + // no more workers will be added to the query + return false; + case ABORTED: + case FAILED: + // DO NOT complete a FAILED or ABORTED stage. This will cause the + // stage above to finish normally, which will result in a query + // completing successfully when it should fail.. 
+ return true; + } + throw new IllegalStateException("Unhandled state: " + this); + } + } + + private static class PipelinedStageStateMachine + { + private static final Set TERMINAL_STAGE_STATES = Stream.of(State.values()).filter(State::isDone).collect(toImmutableSet()); + + private final StageId stageId; + private final StateMachine state; + private final AtomicReference schedulingComplete = new AtomicReference<>(); + private final AtomicReference failureCause = new AtomicReference<>(); + + private PipelinedStageStateMachine(StageId stageId, Executor executor) + { + this.stageId = requireNonNull(stageId, "stageId is null"); + + state = new StateMachine<>("Pipelined stage execution " + stageId, executor, PLANNED, TERMINAL_STAGE_STATES); + state.addStateChangeListener(state -> log.debug("Pipelined stage execution %s is %s", stageId, state)); + } + + public State getState() + { + return state.get(); + } + + public boolean transitionToScheduling() + { + return state.compareAndSet(PLANNED, SCHEDULING); + } + + public boolean transitionToSchedulingSplits() + { + return state.setIf(SCHEDULING_SPLITS, currentState -> currentState == PLANNED || currentState == SCHEDULING); + } + + public boolean transitionToScheduled() + { + schedulingComplete.compareAndSet(null, DateTime.now()); + return state.setIf(SCHEDULED, currentState -> currentState == PLANNED || currentState == SCHEDULING || currentState == SCHEDULING_SPLITS); + } + + public boolean transitionToRunning() + { + return state.setIf(RUNNING, currentState -> currentState != RUNNING && currentState != FLUSHING && !currentState.isDone()); + } + + public boolean transitionToFlushing() + { + return state.setIf(FLUSHING, currentState -> currentState != FLUSHING && !currentState.isDone()); + } + + public boolean transitionToFinished() + { + return state.setIf(FINISHED, currentState -> !currentState.isDone()); + } + + public boolean transitionToCanceled() + { + return state.setIf(CANCELED, currentState -> 
!currentState.isDone()); + } + + public boolean transitionToAborted() + { + return state.setIf(ABORTED, currentState -> !currentState.isDone()); + } + + public boolean transitionToFailed(Throwable throwable) + { + requireNonNull(throwable, "throwable is null"); + + failureCause.compareAndSet(null, Failures.toFailure(throwable)); + boolean failed = state.setIf(FAILED, currentState -> !currentState.isDone()); + if (failed) { + log.error(throwable, "Pipelined stage execution for stage %s failed", stageId); + } + else { + log.debug(throwable, "Failure in pipelined stage execution for stage %s after finished", stageId); + } + return failed; + } + + public Optional getFailureCause() + { + return Optional.ofNullable(failureCause.get()); + } + + /** + * Listener is always notified asynchronously using a dedicated notification thread pool so, care should + * be taken to avoid leaking {@code this} when adding a listener in a constructor. Additionally, it is + * possible notifications are observed out of order due to the asynchronous execution. 
+ */ + public void addStateChangeListener(StateChangeListener stateChangeListener) + { + state.addStateChangeListener(stateChangeListener); + } + } + + private static class ListenerManager + { + private final List> listeners = new ArrayList<>(); + private boolean frozen; + + public synchronized void addListener(Consumer listener) + { + checkState(!frozen, "Listeners have been invoked"); + listeners.add(listener); + } + + public synchronized void invoke(T payload, Executor executor) + { + frozen = true; + for (Consumer listener : listeners) { + executor.execute(() -> listener.accept(payload)); + } + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledOutputBufferManager.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledOutputBufferManager.java index 4af58c3c7990..e5e277264577 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledOutputBufferManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledOutputBufferManager.java @@ -18,55 +18,44 @@ import javax.annotation.concurrent.GuardedBy; -import java.util.List; -import java.util.function.Consumer; - import static io.trino.execution.buffer.OutputBuffers.BufferType.ARBITRARY; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; -import static java.util.Objects.requireNonNull; public class ScaledOutputBufferManager implements OutputBufferManager { - private final Consumer outputBufferTarget; - @GuardedBy("this") private OutputBuffers outputBuffers = createInitialEmptyOutputBuffers(ARBITRARY); - public ScaledOutputBufferManager(Consumer outputBufferTarget) - { - this.outputBufferTarget = requireNonNull(outputBufferTarget, "outputBufferTarget is null"); - outputBufferTarget.accept(outputBuffers); - } - @SuppressWarnings("ObjectEquality") @Override - public void addOutputBuffers(List newBuffers, boolean noMoreBuffers) + public synchronized void addOutputBuffer(OutputBufferId 
newBuffer) { - OutputBuffers newOutputBuffers; - synchronized (this) { - if (outputBuffers.isNoMoreBufferIds()) { - // a stage can move to a final state (e.g., failed) while scheduling, - // so ignore the new buffers - return; - } - - OutputBuffers originalOutputBuffers = outputBuffers; + if (outputBuffers.isNoMoreBufferIds()) { + // a stage can move to a final state (e.g., failed) while scheduling, so ignore + // the new buffers + return; + } - for (OutputBufferId newBuffer : newBuffers) { - outputBuffers = outputBuffers.withBuffer(newBuffer, newBuffer.getId()); - } + OutputBuffers newOutputBuffers = outputBuffers.withBuffer(newBuffer, newBuffer.getId()); - if (noMoreBuffers) { - outputBuffers = outputBuffers.withNoMoreBufferIds(); - } + // don't update if nothing changed + if (newOutputBuffers != outputBuffers) { + this.outputBuffers = newOutputBuffers; + } + } - // don't update if nothing changed - if (outputBuffers == originalOutputBuffers) { - return; - } - newOutputBuffers = this.outputBuffers; + @Override + public synchronized void noMoreBuffers() + { + if (!outputBuffers.isNoMoreBufferIds()) { + outputBuffers = outputBuffers.withNoMoreBufferIds(); } - outputBufferTarget.accept(newOutputBuffers); + } + + @Override + public synchronized OutputBuffers getOutputBuffers() + { + return outputBuffers; } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java index 2bfe6d619e8d..40c70be60fcf 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java @@ -14,10 +14,10 @@ package io.trino.execution.scheduler; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMultimap; import com.google.common.util.concurrent.SettableFuture; import io.airlift.units.DataSize; import 
io.trino.execution.RemoteTask; -import io.trino.execution.SqlStageExecution; import io.trino.execution.TaskStatus; import io.trino.metadata.InternalNode; @@ -39,7 +39,7 @@ public class ScaledWriterScheduler implements StageScheduler { - private final SqlStageExecution stage; + private final PipelinedStageExecution stage; private final Supplier> sourceTasksProvider; private final Supplier> writerTasksProvider; private final NodeSelector nodeSelector; @@ -50,7 +50,7 @@ public class ScaledWriterScheduler private volatile SettableFuture future = SettableFuture.create(); public ScaledWriterScheduler( - SqlStageExecution stage, + PipelinedStageExecution stage, Supplier> sourceTasksProvider, Supplier> writerTasksProvider, NodeSelector nodeSelector, @@ -119,7 +119,7 @@ private List scheduleTasks(int count) ImmutableList.Builder tasks = ImmutableList.builder(); for (InternalNode node : nodes) { - Optional remoteTask = stage.scheduleTask(node, scheduledNodes.size()); + Optional remoteTask = stage.scheduleTask(node, scheduledNodes.size(), ImmutableMultimap.of(), ImmutableMultimap.of()); remoteTask.ifPresent(task -> { tasks.add(task); scheduledNodes.add(node); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java index 45d628b4764d..4014fda9a086 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java @@ -22,7 +22,6 @@ import com.google.common.util.concurrent.SettableFuture; import io.trino.execution.Lifespan; import io.trino.execution.RemoteTask; -import io.trino.execution.SqlStageExecution; import io.trino.execution.TableExecuteContext; import io.trino.execution.TableExecuteContextManager; import io.trino.execution.scheduler.FixedSourcePartitionedScheduler.BucketedSplitPlacementPolicy; @@ 
-92,7 +91,7 @@ private enum State FINISHED } - private final SqlStageExecution stage; + private final PipelinedStageExecution stageExecution; private final SplitSource splitSource; private final SplitPlacementPolicy splitPlacementPolicy; private final int splitBatchSize; @@ -101,6 +100,8 @@ private enum State private final DynamicFilterService dynamicFilterService; private final TableExecuteContextManager tableExecuteContextManager; private final BooleanSupplier anySourceTaskBlocked; + private final PartitionIdAllocator partitionIdAllocator; + private final Map scheduledTasks; private final Map scheduleGroups = new HashMap<>(); private boolean noMoreScheduleGroups; @@ -109,7 +110,7 @@ private enum State private SettableFuture whenFinishedOrNewLifespanAdded = SettableFuture.create(); private SourcePartitionedScheduler( - SqlStageExecution stage, + PipelinedStageExecution stageExecution, PlanNodeId partitionedNode, SplitSource splitSource, SplitPlacementPolicy splitPlacementPolicy, @@ -117,19 +118,22 @@ private SourcePartitionedScheduler( boolean groupedExecution, DynamicFilterService dynamicFilterService, TableExecuteContextManager tableExecuteContextManager, - BooleanSupplier anySourceTaskBlocked) + BooleanSupplier anySourceTaskBlocked, + PartitionIdAllocator partitionIdAllocator, + Map scheduledTasks) { - this.stage = requireNonNull(stage, "stage is null"); - this.partitionedNode = requireNonNull(partitionedNode, "partitionedNode is null"); + this.stageExecution = requireNonNull(stageExecution, "stageExecution is null"); this.splitSource = requireNonNull(splitSource, "splitSource is null"); this.splitPlacementPolicy = requireNonNull(splitPlacementPolicy, "splitPlacementPolicy is null"); - this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); - this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); - this.anySourceTaskBlocked = 
requireNonNull(anySourceTaskBlocked, "anySourceTaskBlocked is null"); - checkArgument(splitBatchSize > 0, "splitBatchSize must be at least one"); this.splitBatchSize = splitBatchSize; + this.partitionedNode = requireNonNull(partitionedNode, "partitionedNode is null"); this.groupedExecution = groupedExecution; + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); + this.anySourceTaskBlocked = requireNonNull(anySourceTaskBlocked, "anySourceTaskBlocked is null"); + this.partitionIdAllocator = requireNonNull(partitionIdAllocator, "partitionIdAllocator is null"); + this.scheduledTasks = requireNonNull(scheduledTasks, "scheduledTasks is null"); } @Override @@ -146,7 +150,7 @@ public PlanNodeId getPlanNodeId() * minimal management from the caller, which is ideal for use as a stage scheduler. */ public static StageScheduler newSourcePartitionedSchedulerAsStageScheduler( - SqlStageExecution stage, + PipelinedStageExecution stageExecution, PlanNodeId partitionedNode, SplitSource splitSource, SplitPlacementPolicy splitPlacementPolicy, @@ -156,7 +160,7 @@ public static StageScheduler newSourcePartitionedSchedulerAsStageScheduler( BooleanSupplier anySourceTaskBlocked) { SourcePartitionedScheduler sourcePartitionedScheduler = new SourcePartitionedScheduler( - stage, + stageExecution, partitionedNode, splitSource, splitPlacementPolicy, @@ -164,7 +168,9 @@ public static StageScheduler newSourcePartitionedSchedulerAsStageScheduler( false, dynamicFilterService, tableExecuteContextManager, - anySourceTaskBlocked); + anySourceTaskBlocked, + new PartitionIdAllocator(), + new HashMap<>()); sourcePartitionedScheduler.startLifespan(Lifespan.taskWide(), NOT_PARTITIONED); sourcePartitionedScheduler.noMoreLifespans(); @@ -198,7 +204,7 @@ public void close() * transitioning of the object will not work properly. 
*/ public static SourceScheduler newSourcePartitionedSchedulerAsSourceScheduler( - SqlStageExecution stage, + PipelinedStageExecution stageExecution, PlanNodeId partitionedNode, SplitSource splitSource, SplitPlacementPolicy splitPlacementPolicy, @@ -206,10 +212,12 @@ public static SourceScheduler newSourcePartitionedSchedulerAsSourceScheduler( boolean groupedExecution, DynamicFilterService dynamicFilterService, TableExecuteContextManager tableExecuteContextManager, - BooleanSupplier anySourceTaskBlocked) + BooleanSupplier anySourceTaskBlocked, + PartitionIdAllocator partitionIdAllocator, + Map scheduledTasks) { return new SourcePartitionedScheduler( - stage, + stageExecution, partitionedNode, splitSource, splitPlacementPolicy, @@ -217,7 +225,9 @@ public static SourceScheduler newSourcePartitionedSchedulerAsSourceScheduler( groupedExecution, dynamicFilterService, tableExecuteContextManager, - anySourceTaskBlocked); + anySourceTaskBlocked, + partitionIdAllocator, + scheduledTasks); } @Override @@ -267,7 +277,7 @@ else if (pendingSplits.isEmpty()) { scheduleGroup.nextSplitBatchFuture = splitSource.getNextBatch(scheduleGroup.partitionHandle, lifespan, splitBatchSize - pendingSplits.size()); long start = System.nanoTime(); - addSuccessCallback(scheduleGroup.nextSplitBatchFuture, () -> stage.recordGetSplitTime(start)); + addSuccessCallback(scheduleGroup.nextSplitBatchFuture, () -> stageExecution.recordGetSplitTime(start)); } if (scheduleGroup.nextSplitBatchFuture.isDone()) { @@ -373,7 +383,7 @@ else if (pendingSplits.isEmpty()) { // Here we assume that we can get non-empty tableExecuteSplitsInfo only for queries which facilitate single split source. 
// TODO support grouped execution tableExecuteSplitsInfo.ifPresent(info -> { - TableExecuteContext tableExecuteContext = tableExecuteContextManager.getTableExecuteContextForQuery(stage.getStageId().getQueryId()); + TableExecuteContext tableExecuteContext = tableExecuteContextManager.getTableExecuteContextForQuery(stageExecution.getStageId().getQueryId()); tableExecuteContext.setSplitsInfo(info); }); @@ -397,17 +407,17 @@ else if (pendingSplits.isEmpty()) { } if (anyBlockedOnNextSplitBatch - && stage.getScheduledNodes().isEmpty() - && dynamicFilterService.isCollectingTaskNeeded(stage.getStageId().getQueryId(), stage.getFragment())) { + && scheduledTasks.isEmpty() + && dynamicFilterService.isCollectingTaskNeeded(stageExecution.getStageId().getQueryId(), stageExecution.getFragment())) { // schedule a task for collecting dynamic filters in case probe split generator is waiting for them - overallNewTasks.addAll(createTaskOnRandomNode()); + createTaskOnRandomNode().ifPresent(overallNewTasks::add); } boolean anySourceTaskBlocked = this.anySourceTaskBlocked.getAsBoolean(); if (anySourceTaskBlocked) { // Dynamic filters might not be collected due to build side source tasks being blocked on full buffer. // In such case probe split generation that is waiting for dynamic filters should be unblocked to prevent deadlock. 
- dynamicFilterService.unblockStageDynamicFilters(stage.getStageId().getQueryId(), stage.getFragment()); + dynamicFilterService.unblockStageDynamicFilters(stageExecution.getStageId().getQueryId(), stageExecution.getFragment()); } if (groupedExecution) { @@ -519,44 +529,56 @@ private Set assignSplits(Multimap splitAssignme if (noMoreSplitsNotification.containsKey(node)) { noMoreSplits.putAll(partitionedNode, noMoreSplitsNotification.get(node)); } - newTasks.addAll(stage.scheduleSplits( - node, - splits, - noMoreSplits.build())); + RemoteTask task = scheduledTasks.get(node); + if (task != null) { + task.addSplits(splits); + noMoreSplits.build().forEach(task::noMoreSplits); + } + else { + scheduleTask(node, splits, noMoreSplits.build()).ifPresent(newTasks::add); + } } return newTasks.build(); } - private Set createTaskOnRandomNode() + private Optional createTaskOnRandomNode() { - checkState(stage.getScheduledNodes().isEmpty(), "Stage task is already scheduled on node"); + checkState(scheduledTasks.isEmpty(), "Stage task is already scheduled on node"); List allNodes = splitPlacementPolicy.allNodes(); checkState(allNodes.size() > 0, "No nodes available"); InternalNode node = allNodes.get(ThreadLocalRandom.current().nextInt(0, allNodes.size())); - return stage.scheduleSplits(node, ImmutableMultimap.of(), ImmutableMultimap.of()); + return scheduleTask(node, ImmutableMultimap.of(), ImmutableMultimap.of()); } private Set finalizeTaskCreationIfNecessary() { // only lock down tasks if there is a sub stage that could block waiting for this stage to create all tasks - if (stage.getFragment().isLeaf()) { + if (stageExecution.getFragment().isLeaf()) { return ImmutableSet.of(); } splitPlacementPolicy.lockDownNodes(); - Set scheduledNodes = stage.getScheduledNodes(); Set newTasks = splitPlacementPolicy.allNodes().stream() - .filter(node -> !scheduledNodes.contains(node)) - .flatMap(node -> stage.scheduleSplits(node, ImmutableMultimap.of(), ImmutableMultimap.of()).stream()) + 
.filter(node -> !scheduledTasks.containsKey(node)) + .map(node -> scheduleTask(node, ImmutableMultimap.of(), ImmutableMultimap.of())) + .filter(Optional::isPresent) + .map(Optional::get) .collect(toImmutableSet()); // notify listeners that we have scheduled all tasks so they can set no more buffers or exchange splits - stage.transitionToSchedulingSplits(); + stageExecution.transitionToSchedulingSplits(); return newTasks; } + private Optional scheduleTask(InternalNode node, Multimap initialSplits, Multimap noMoreSplitsForLifespan) + { + Optional remoteTask = stageExecution.scheduleTask(node, partitionIdAllocator.getNextId(), initialSplits, noMoreSplitsForLifespan); + remoteTask.ifPresent(task -> scheduledTasks.put(node, task)); + return remoteTask; + } + private static class ScheduleGroup { public final ConnectorPartitionHandle partitionHandle; diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java index e800b98fbe73..a0742db9f03d 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java @@ -13,45 +13,64 @@ */ package io.trino.execution.scheduler; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; +import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Sets; +import com.google.common.graph.Traverser; import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.airlift.concurrent.SetThreadName; +import io.airlift.log.Logger; import io.airlift.stats.TimeStat; import io.airlift.units.Duration; import 
io.trino.Session; import io.trino.connector.CatalogName; import io.trino.execution.BasicStageStats; +import io.trino.execution.ExecutionFailureInfo; import io.trino.execution.NodeTaskMap; import io.trino.execution.QueryState; import io.trino.execution.QueryStateMachine; import io.trino.execution.RemoteTask; import io.trino.execution.RemoteTaskFactory; -import io.trino.execution.SqlStageExecution; +import io.trino.execution.SqlStage; import io.trino.execution.StageId; import io.trino.execution.StageInfo; -import io.trino.execution.StageState; +import io.trino.execution.StateMachine; +import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.TableExecuteContextManager; +import io.trino.execution.TableInfo; +import io.trino.execution.TaskFailureListener; +import io.trino.execution.TaskId; +import io.trino.execution.TaskManager; import io.trino.execution.TaskStatus; -import io.trino.execution.buffer.OutputBuffers; -import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.failuredetector.FailureDetector; import io.trino.metadata.InternalNode; +import io.trino.metadata.Metadata; +import io.trino.metadata.TableProperties; +import io.trino.metadata.TableSchema; +import io.trino.operator.RetryPolicy; import io.trino.server.DynamicFilterService; +import io.trino.spi.ErrorCode; +import io.trino.spi.QueryId; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorPartitionHandle; import io.trino.split.SplitSource; import io.trino.sql.planner.NodePartitionMap; import io.trino.sql.planner.NodePartitioningManager; import io.trino.sql.planner.PartitioningHandle; -import io.trino.sql.planner.StageExecutionPlan; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.SplitSourceFactory; +import io.trino.sql.planner.SubPlan; import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; +import 
io.trino.sql.planner.plan.RemoteSourceNode; +import io.trino.sql.planner.plan.TableScanNode; + +import javax.annotation.concurrent.GuardedBy; import java.net.URI; import java.util.ArrayList; @@ -64,13 +83,17 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.Set; +import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -78,658 +101,1670 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getFirst; +import static com.google.common.collect.Iterables.getLast; +import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.Sets.newConcurrentHashSet; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; import static io.airlift.concurrent.MoreFutures.whenAnyComplete; import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static io.trino.SystemSessionProperties.getConcurrentLifespansPerNode; +import static io.trino.SystemSessionProperties.getRetryAttempts; +import static io.trino.SystemSessionProperties.getRetryInitialDelay; +import static io.trino.SystemSessionProperties.getRetryMaxDelay; +import static io.trino.SystemSessionProperties.getRetryPolicy; 
import static io.trino.SystemSessionProperties.getWriterMinSize; import static io.trino.connector.CatalogName.isInternalSystemConnector; import static io.trino.execution.BasicStageStats.aggregateBasicStageStats; -import static io.trino.execution.SqlStageExecution.createSqlStageExecution; -import static io.trino.execution.StageState.ABORTED; -import static io.trino.execution.StageState.CANCELED; -import static io.trino.execution.StageState.FAILED; -import static io.trino.execution.StageState.FINISHED; -import static io.trino.execution.StageState.FLUSHING; -import static io.trino.execution.StageState.RUNNING; -import static io.trino.execution.StageState.SCHEDULED; +import static io.trino.execution.SqlStage.createSqlStage; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.ABORTED; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.CANCELED; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.FAILED; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.FINISHED; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.FLUSHING; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.RUNNING; +import static io.trino.execution.scheduler.PipelinedStageExecution.State.SCHEDULED; +import static io.trino.execution.scheduler.PipelinedStageExecution.createPipelinedStageExecution; import static io.trino.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsStageScheduler; +import static io.trino.spi.ErrorType.EXTERNAL; +import static io.trino.spi.ErrorType.INTERNAL_ERROR; +import static io.trino.spi.StandardErrorCode.CLUSTER_OUT_OF_MEMORY; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; +import static io.trino.spi.StandardErrorCode.REMOTE_TASK_FAILED; import static io.trino.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; import static 
io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; +import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; import static io.trino.util.Failures.checkCondition; +import static io.trino.util.Failures.toFailure; +import static java.lang.Integer.parseInt; +import static java.lang.Math.min; +import static java.lang.Math.pow; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; import static java.util.function.Function.identity; import static java.util.stream.Collectors.toCollection; public class SqlQueryScheduler { + private static final Logger log = Logger.get(SqlQueryScheduler.class); + private final QueryStateMachine queryStateMachine; - private final ExecutionPolicy executionPolicy; - private final Map stages; + private final NodePartitioningManager nodePartitioningManager; + private final NodeScheduler nodeScheduler; + private final int splitBatchSize; private final ExecutorService executor; - private final StageId rootStageId; - private final Map stageSchedulers; - private final Map stageLinkages; + private final ScheduledExecutorService schedulerExecutor; + private final FailureDetector failureDetector; + private final ExecutionPolicy executionPolicy; private final SplitSchedulerStats schedulerStats; - private final boolean summarizeTaskInfo; private final DynamicFilterService dynamicFilterService; private final TableExecuteContextManager tableExecuteContextManager; - private final AtomicBoolean started = new AtomicBoolean(); + private final SplitSourceFactory splitSourceFactory; - 
public static SqlQueryScheduler createSqlQueryScheduler( - QueryStateMachine queryStateMachine, - StageExecutionPlan plan, - NodePartitioningManager nodePartitioningManager, - NodeScheduler nodeScheduler, - RemoteTaskFactory remoteTaskFactory, - Session session, - boolean summarizeTaskInfo, - int splitBatchSize, - ExecutorService queryExecutor, - ScheduledExecutorService schedulerExecutor, - FailureDetector failureDetector, - OutputBuffers rootOutputBuffers, - NodeTaskMap nodeTaskMap, - ExecutionPolicy executionPolicy, - SplitSchedulerStats schedulerStats, - DynamicFilterService dynamicFilterService, - TableExecuteContextManager tableExecuteContextManager) - { - SqlQueryScheduler sqlQueryScheduler = new SqlQueryScheduler( - queryStateMachine, - plan, - nodePartitioningManager, - nodeScheduler, - remoteTaskFactory, - session, - summarizeTaskInfo, - splitBatchSize, - queryExecutor, - schedulerExecutor, - failureDetector, - rootOutputBuffers, - nodeTaskMap, - executionPolicy, - schedulerStats, - dynamicFilterService, - tableExecuteContextManager); - sqlQueryScheduler.initialize(); - return sqlQueryScheduler; - } + private final StageManager stageManager; + private final CoordinatorStagesScheduler coordinatorStagesScheduler; - private SqlQueryScheduler( + private final RetryPolicy retryPolicy; + private final int maxRetryAttempts; + private final AtomicInteger currentAttempt = new AtomicInteger(); + private final Duration retryInitialDelay; + private final Duration retryMaxDelay; + + @GuardedBy("this") + private boolean started; + + @GuardedBy("this") + private final AtomicReference distributedStagesScheduler = new AtomicReference<>(); + @GuardedBy("this") + private Future distributedStagesSchedulingTask; + + public SqlQueryScheduler( QueryStateMachine queryStateMachine, - StageExecutionPlan plan, + SubPlan plan, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, RemoteTaskFactory remoteTaskFactory, - Session session, boolean 
summarizeTaskInfo, int splitBatchSize, ExecutorService queryExecutor, ScheduledExecutorService schedulerExecutor, FailureDetector failureDetector, - OutputBuffers rootOutputBuffers, NodeTaskMap nodeTaskMap, ExecutionPolicy executionPolicy, SplitSchedulerStats schedulerStats, DynamicFilterService dynamicFilterService, - TableExecuteContextManager tableExecuteContextManager) + TableExecuteContextManager tableExecuteContextManager, + Metadata metadata, + SplitSourceFactory splitSourceFactory, + TaskManager coordinatorTaskManager) { this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); + this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); + this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + this.splitBatchSize = splitBatchSize; + this.executor = requireNonNull(queryExecutor, "queryExecutor is null"); + this.schedulerExecutor = requireNonNull(schedulerExecutor, "schedulerExecutor is null"); + this.failureDetector = requireNonNull(failureDetector, "failureDetector is null"); this.executionPolicy = requireNonNull(executionPolicy, "executionPolicy is null"); this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); - this.summarizeTaskInfo = summarizeTaskInfo; this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); + this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); - // todo come up with a better way to build this, or eliminate this map - ImmutableMap.Builder stageSchedulers = ImmutableMap.builder(); - ImmutableMap.Builder stageLinkages = ImmutableMap.builder(); - - // Only fetch a distribution once per query to assure all stages see the same machine assignments - Map partitioningCache = new HashMap<>(); - - OutputBufferId rootBufferId = 
Iterables.getOnlyElement(rootOutputBuffers.getBuffers().keySet()); - List stages = createStages( - (fragmentId, tasks, noMoreExchangeLocations) -> updateQueryOutputLocations(queryStateMachine, rootBufferId, tasks, noMoreExchangeLocations), - new AtomicInteger(), - plan.withBucketToPartition(Optional.of(new int[1])), - nodeScheduler, + stageManager = StageManager.create( + queryStateMachine, + queryStateMachine.getSession(), + metadata, remoteTaskFactory, - session, - splitBatchSize, - partitioningHandle -> partitioningCache.computeIfAbsent(partitioningHandle, handle -> nodePartitioningManager.getNodePartitioningMap(session, handle)), - nodePartitioningManager, + nodeTaskMap, queryExecutor, - schedulerExecutor, + schedulerStats, + plan, + summarizeTaskInfo); + + coordinatorStagesScheduler = CoordinatorStagesScheduler.create( + queryStateMachine, + nodeScheduler, + stageManager, failureDetector, - nodeTaskMap, - stageSchedulers, - stageLinkages); + schedulerExecutor, + distributedStagesScheduler, + coordinatorTaskManager); + + retryPolicy = getRetryPolicy(queryStateMachine.getSession()); + maxRetryAttempts = getRetryAttempts(queryStateMachine.getSession()); + retryInitialDelay = getRetryInitialDelay(queryStateMachine.getSession()); + retryMaxDelay = getRetryMaxDelay(queryStateMachine.getSession()); + } + + public synchronized void start() + { + if (started) { + return; + } + started = true; + + if (queryStateMachine.isDone()) { + return; + } + + // when query is done or any time a stage completes, attempt to transition query to "final query info ready" + queryStateMachine.addStateChangeListener(state -> { + if (!state.isDone()) { + return; + } - SqlStageExecution rootStage = stages.get(0); - rootStage.setOutputBuffers(rootOutputBuffers); - this.rootStageId = rootStage.getStageId(); + DistributedStagesScheduler distributedStagesScheduler; + // synchronize to wait on distributed scheduler creation if it is currently in process + synchronized (this) { + 
distributedStagesScheduler = this.distributedStagesScheduler.get(); + } + + if (state == QueryState.FINISHED) { + coordinatorStagesScheduler.cancel(); + if (distributedStagesScheduler != null) { + distributedStagesScheduler.cancel(); + } + stageManager.finish(); + } + else if (state == QueryState.FAILED) { + coordinatorStagesScheduler.abort(); + if (distributedStagesScheduler != null) { + distributedStagesScheduler.abort(); + } + stageManager.abort(); + } - this.stages = stages.stream() - .collect(toImmutableMap(SqlStageExecution::getStageId, identity())); + queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo())); + }); - this.stageSchedulers = stageSchedulers.build(); - this.stageLinkages = stageLinkages.build(); + coordinatorStagesScheduler.schedule(); - this.executor = queryExecutor; + Optional distributedStagesScheduler = createDistributedStagesScheduler(currentAttempt.get()); + distributedStagesScheduler.ifPresent(scheduler -> distributedStagesSchedulingTask = executor.submit(scheduler::schedule, null)); } - // this is a separate method to ensure that the `this` reference is not leaked during construction - private void initialize() + private synchronized Optional createDistributedStagesScheduler(int attempt) { - SqlStageExecution rootStage = stages.get(rootStageId); - rootStage.addStateChangeListener(state -> { - if (state == FINISHED) { - queryStateMachine.transitionToFinishing(); + if (queryStateMachine.isDone()) { + return Optional.empty(); + } + DistributedStagesScheduler distributedStagesScheduler = PipelinedDistributedStagesScheduler.create( + queryStateMachine, + schedulerStats, + nodeScheduler, + nodePartitioningManager, + stageManager, + coordinatorStagesScheduler, + executionPolicy, + failureDetector, + schedulerExecutor, + splitSourceFactory, + splitBatchSize, + dynamicFilterService, + tableExecuteContextManager, + retryPolicy, + attempt); + this.distributedStagesScheduler.set(distributedStagesScheduler); + 
distributedStagesScheduler.addStateChangeListener(state -> { + if (queryStateMachine.getQueryState() == QueryState.STARTING && state.isRunningOrDone()) { + queryStateMachine.transitionToRunning(); } - else if (state == CANCELED) { - // output stage was canceled - queryStateMachine.transitionToCanceled(); + + if (state.isDone() && !state.isFailure()) { + stageManager.getDistributedStagesInTopologicalOrder().forEach(stage -> stageManager.get(stage.getStageId()).finish()); } - }); - for (SqlStageExecution stage : stages.values()) { - stage.addStateChangeListener(state -> { - if (queryStateMachine.isDone()) { - return; + if (stageManager.getCoordinatorStagesInTopologicalOrder().isEmpty()) { + // if there are no coordinator stages (e.g., simple select query) and the distributed stages are finished, do the query transitioning + // otherwise defer query transitioning to the coordinator stages + if (state == DistributedStagesSchedulerState.FINISHED) { + queryStateMachine.transitionToFinishing(); } - if (state == FAILED) { - queryStateMachine.transitionToFailed(stage.getStageInfo().getFailureCause().toException()); + else if (state == DistributedStagesSchedulerState.CANCELED) { + // output stage was canceled + queryStateMachine.transitionToCanceled(); } - else if (state == ABORTED) { - // this should never happen, since abort can only be triggered in query clean up after the query is finished - queryStateMachine.transitionToFailed(new TrinoException(GENERIC_INTERNAL_ERROR, "Query stage was aborted")); + } + + if (state == DistributedStagesSchedulerState.FAILED) { + StageFailureInfo stageFailureInfo = distributedStagesScheduler.getFailureCause() + .orElseGet(() -> new StageFailureInfo(toFailure(new VerifyException("distributedStagesScheduler failed but failure cause is not present")), Optional.empty())); + ErrorCode errorCode = stageFailureInfo.getFailureInfo().getErrorCode(); + if (shouldRetry(errorCode)) { + long delayInMillis = min(retryInitialDelay.toMillis() * ((long) 
pow(2, currentAttempt.get())), retryMaxDelay.toMillis()); + currentAttempt.incrementAndGet(); + scheduleRetryWithDelay(delayInMillis); } - else if (queryStateMachine.getQueryState() == QueryState.STARTING) { - // if the stage has at least one task, we are running - if (stage.hasTasks()) { - queryStateMachine.transitionToRunning(); - } + else { + stageManager.getDistributedStagesInTopologicalOrder().forEach(stage -> { + if (stageFailureInfo.getFailedStageId().isPresent() && stageFailureInfo.getFailedStageId().get().equals(stage.getStageId())) { + stage.fail(stageFailureInfo.getFailureInfo().toException()); + } + else { + stage.abort(); + } + }); + queryStateMachine.transitionToFailed(stageFailureInfo.getFailureInfo().toException()); } - }); - } - - // when query is done or any time a stage completes, attempt to transition query to "final query info ready" - queryStateMachine.addStateChangeListener(newState -> { - if (newState.isDone()) { - queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo())); } }); - for (SqlStageExecution stage : stages.values()) { - stage.addFinalStageInfoListener(status -> queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo()))); - } + return Optional.of(distributedStagesScheduler); } - private static void updateQueryOutputLocations(QueryStateMachine queryStateMachine, OutputBufferId rootBufferId, Set tasks, boolean noMoreExchangeLocations) + private boolean shouldRetry(ErrorCode errorCode) { - Set bufferLocations = tasks.stream() - .map(task -> task.getTaskStatus().getSelf()) - .map(location -> uriBuilderFrom(location).appendPath("results").appendPath(rootBufferId.toString()).build()) - .collect(toImmutableSet()); - queryStateMachine.updateOutputLocations(bufferLocations, noMoreExchangeLocations); + return retryPolicy == RetryPolicy.QUERY && currentAttempt.get() < maxRetryAttempts && isRetryableErrorCode(errorCode); } - private List createStages( - ExchangeLocationsConsumer parent, - AtomicInteger nextStageId, - 
StageExecutionPlan plan, - NodeScheduler nodeScheduler, - RemoteTaskFactory remoteTaskFactory, - Session session, - int splitBatchSize, - Function partitioningCache, - NodePartitioningManager nodePartitioningManager, - ExecutorService queryExecutor, - ScheduledExecutorService schedulerExecutor, - FailureDetector failureDetector, - NodeTaskMap nodeTaskMap, - ImmutableMap.Builder stageSchedulers, - ImmutableMap.Builder stageLinkages) + private static boolean isRetryableErrorCode(ErrorCode errorCode) { - ImmutableList.Builder stages = ImmutableList.builder(); - - StageId stageId = new StageId(queryStateMachine.getQueryId(), nextStageId.getAndIncrement()); - SqlStageExecution stage = createSqlStageExecution( - stageId, - plan.getFragment(), - plan.getTables(), - remoteTaskFactory, - session, - summarizeTaskInfo, - nodeTaskMap, - queryExecutor, - failureDetector, - dynamicFilterService, - schedulerStats); - stages.add(stage); - - // function to create child stages recursively by supplying the bucket partitioning (according to parent's partitioning) - Function, Set> createChildStages = bucketToPartition -> { - ImmutableSet.Builder childStagesBuilder = ImmutableSet.builder(); - for (StageExecutionPlan subStagePlan : plan.getSubStages()) { - List subTree = createStages( - stage::addExchangeLocations, - nextStageId, - subStagePlan.withBucketToPartition(bucketToPartition), - nodeScheduler, - remoteTaskFactory, - session, - splitBatchSize, - partitioningCache, - nodePartitioningManager, - queryExecutor, - schedulerExecutor, - failureDetector, - nodeTaskMap, - stageSchedulers, - stageLinkages); - stages.addAll(subTree); - - SqlStageExecution childStage = subTree.get(0); - childStagesBuilder.add(childStage); - } - return childStagesBuilder.build(); - }; - - Set childStages; - PartitioningHandle partitioningHandle = plan.getFragment().getPartitioning(); - if (partitioningHandle.equals(SOURCE_DISTRIBUTION)) { - // nodes are selected dynamically based on the constraints of the 
splits and the system load - Entry entry = Iterables.getOnlyElement(plan.getSplitSources().entrySet()); - PlanNodeId planNodeId = entry.getKey(); - SplitSource splitSource = entry.getValue(); - Optional catalogName = Optional.of(splitSource.getCatalogName()) - .filter(catalog -> !isInternalSystemConnector(catalog)); - NodeSelector nodeSelector = nodeScheduler.createNodeSelector(session, catalogName); - SplitPlacementPolicy placementPolicy = new DynamicSplitPlacementPolicy(nodeSelector, stage::getAllTasks); - - checkArgument(!plan.getFragment().getStageExecutionDescriptor().isStageGroupedExecution()); - - childStages = createChildStages.apply(Optional.of(new int[1])); - stageSchedulers.put(stageId, newSourcePartitionedSchedulerAsStageScheduler( - stage, - planNodeId, - splitSource, - placementPolicy, - splitBatchSize, - dynamicFilterService, - tableExecuteContextManager, - () -> childStages.stream().anyMatch(SqlStageExecution::isAnyTaskBlocked))); - } - else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { - childStages = createChildStages.apply(Optional.of(new int[1])); - Supplier> sourceTasksProvider = () -> childStages.stream() - .map(SqlStageExecution::getTaskStatuses) - .flatMap(List::stream) - .collect(toImmutableList()); - Supplier> writerTasksProvider = stage::getTaskStatuses; - - ScaledWriterScheduler scheduler = new ScaledWriterScheduler( - stage, - sourceTasksProvider, - writerTasksProvider, - nodeScheduler.createNodeSelector(session, Optional.empty()), - schedulerExecutor, - getWriterMinSize(session)); - whenAllStages(childStages, StageState::isDone) - .addListener(scheduler::finish, directExecutor()); - stageSchedulers.put(stageId, scheduler); - } - else { - Optional bucketToPartition; - Map splitSources = plan.getSplitSources(); - if (!splitSources.isEmpty()) { - // contains local source - List schedulingOrder = plan.getFragment().getPartitionedSources(); - Optional catalogName = partitioningHandle.getConnectorId(); - 
checkArgument(catalogName.isPresent(), "No connector ID for partitioning handle: %s", partitioningHandle); - List connectorPartitionHandles; - boolean groupedExecutionForStage = plan.getFragment().getStageExecutionDescriptor().isStageGroupedExecution(); - if (groupedExecutionForStage) { - connectorPartitionHandles = nodePartitioningManager.listPartitionHandles(session, partitioningHandle); - checkState(!ImmutableList.of(NOT_PARTITIONED).equals(connectorPartitionHandles)); - } - else { - connectorPartitionHandles = ImmutableList.of(NOT_PARTITIONED); - } + return errorCode == null + || errorCode.getType() == INTERNAL_ERROR + || errorCode.getType() == EXTERNAL + || errorCode.getCode() == CLUSTER_OUT_OF_MEMORY.toErrorCode().getCode(); + } - BucketNodeMap bucketNodeMap; - List stageNodeList; - if (plan.getFragment().getRemoteSourceNodes().stream().allMatch(node -> node.getExchangeType() == REPLICATE)) { - // no remote source - boolean dynamicLifespanSchedule = plan.getFragment().getStageExecutionDescriptor().isDynamicLifespanSchedule(); - bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle, dynamicLifespanSchedule); + private void scheduleRetryWithDelay(long delayInMillis) + { + try { + schedulerExecutor.schedule(this::scheduleRetry, delayInMillis, MILLISECONDS); + } + catch (Throwable t) { + queryStateMachine.transitionToFailed(t); + } + } - // verify execution is consistent with planner's decision on dynamic lifespan schedule - verify(bucketNodeMap.isDynamic() == dynamicLifespanSchedule); + private synchronized void scheduleRetry() + { + try { + checkState(distributedStagesSchedulingTask != null, "schedulingTask is expected to be set"); - stageNodeList = new ArrayList<>(nodeScheduler.createNodeSelector(session, catalogName).allNodes()); - Collections.shuffle(stageNodeList); - bucketToPartition = Optional.empty(); - } - else { - // cannot use dynamic lifespan schedule - 
verify(!plan.getFragment().getStageExecutionDescriptor().isDynamicLifespanSchedule()); + // give current scheduler some time to terminate, usually it is expected to be done right away + distributedStagesSchedulingTask.get(5, MINUTES); - // remote source requires nodePartitionMap - NodePartitionMap nodePartitionMap = partitioningCache.apply(plan.getFragment().getPartitioning()); - if (groupedExecutionForStage) { - checkState(connectorPartitionHandles.size() == nodePartitionMap.getBucketToPartition().length); - } - stageNodeList = nodePartitionMap.getPartitionToNode(); - bucketNodeMap = nodePartitionMap.asBucketNodeMap(); - bucketToPartition = Optional.of(nodePartitionMap.getBucketToPartition()); - } + Optional distributedStagesScheduler = createDistributedStagesScheduler(currentAttempt.get()); + distributedStagesScheduler.ifPresent(scheduler -> distributedStagesSchedulingTask = executor.submit(scheduler::schedule, null)); + } + catch (Throwable t) { + queryStateMachine.transitionToFailed(t); + } + } - stageSchedulers.put(stageId, new FixedSourcePartitionedScheduler( - stage, - splitSources, - plan.getFragment().getStageExecutionDescriptor(), - schedulingOrder, - stageNodeList, - bucketNodeMap, - splitBatchSize, - getConcurrentLifespansPerNode(session), - nodeScheduler.createNodeSelector(session, catalogName), - connectorPartitionHandles, - dynamicFilterService, - tableExecuteContextManager)); - } - else { - // all sources are remote - NodePartitionMap nodePartitionMap = partitioningCache.apply(plan.getFragment().getPartitioning()); - List partitionToNode = nodePartitionMap.getPartitionToNode(); - // todo this should asynchronously wait a standard timeout period before failing - checkCondition(!partitionToNode.isEmpty(), NO_NODES_AVAILABLE, "No worker nodes available"); - stageSchedulers.put(stageId, new FixedCountScheduler(stage, partitionToNode)); - bucketToPartition = Optional.of(nodePartitionMap.getBucketToPartition()); + public synchronized void 
cancelStage(StageId stageId) + { + try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { + coordinatorStagesScheduler.cancelStage(stageId); + DistributedStagesScheduler distributedStagesScheduler = this.distributedStagesScheduler.get(); + if (distributedStagesScheduler != null) { + distributedStagesScheduler.cancelStage(stageId); } - childStages = createChildStages.apply(bucketToPartition); } + } - stage.addStateChangeListener(newState -> { - if (newState == FLUSHING || newState.isDone()) { - childStages.forEach(SqlStageExecution::cancel); + public synchronized void abort() + { + try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { + coordinatorStagesScheduler.abort(); + DistributedStagesScheduler distributedStagesScheduler = this.distributedStagesScheduler.get(); + if (distributedStagesScheduler != null) { + distributedStagesScheduler.abort(); } - }); - - stageLinkages.put(stageId, new StageLinkage(plan.getFragment().getId(), parent, childStages)); - - return stages.build(); + } } public BasicStageStats getBasicStageStats() { - List stageStats = stages.values().stream() - .map(SqlStageExecution::getBasicStageStats) - .collect(toImmutableList()); - - return aggregateBasicStageStats(stageStats); + return stageManager.getBasicStageStats(); } public StageInfo getStageInfo() { - Map stageInfos = stages.values().stream() - .map(SqlStageExecution::getStageInfo) - .collect(toImmutableMap(StageInfo::getStageId, identity())); - - return buildStageInfo(rootStageId, stageInfos); - } - - private StageInfo buildStageInfo(StageId stageId, Map stageInfos) - { - StageInfo parent = stageInfos.get(stageId); - checkArgument(parent != null, "No stageInfo for %s", parent); - List childStages = stageLinkages.get(stageId).getChildStageIds().stream() - .map(childStageId -> buildStageInfo(childStageId, stageInfos)) - .collect(toImmutableList()); - if (childStages.isEmpty()) { - return parent; - } - return new 
StageInfo( - parent.getStageId(), - parent.getState(), - parent.getPlan(), - parent.getTypes(), - parent.getStageStats(), - parent.getTasks(), - childStages, - parent.getTables(), - parent.getFailureCause()); + return stageManager.getStageInfo(); } public long getUserMemoryReservation() { - return stages.values().stream() - .mapToLong(SqlStageExecution::getUserMemoryReservation) - .sum(); + return stageManager.getUserMemoryReservation(); } public long getTotalMemoryReservation() { - return stages.values().stream() - .mapToLong(SqlStageExecution::getTotalMemoryReservation) - .sum(); + return stageManager.getTotalMemoryReservation(); } public Duration getTotalCpuTime() { - long millis = stages.values().stream() - .mapToLong(stage -> stage.getTotalCpuTime().toMillis()) - .sum(); - return new Duration(millis, MILLISECONDS); + return stageManager.getTotalCpuTime(); } - public void start() + private static class StageManager { - if (started.compareAndSet(false, true)) { - executor.submit(this::schedule); - } - } + private final QueryStateMachine queryStateMachine; + private final Map stages; + private final List coordinatorStagesInTopologicalOrder; + private final List distributedStagesInTopologicalOrder; + private final StageId rootStageId; + private final Map> children; + private final Map parents; - private void schedule() - { - try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { - Set completedStages = new HashSet<>(); - ExecutionSchedule executionSchedule = executionPolicy.createExecutionSchedule(stages.values()); - while (!executionSchedule.isFinished()) { - List> blockedStages = new ArrayList<>(); - for (SqlStageExecution stage : executionSchedule.getStagesToSchedule()) { - stage.beginScheduling(); - - // perform some scheduling work - ScheduleResult result = stageSchedulers.get(stage.getStageId()) - .schedule(); - - // modify parent and children based on the results of the scheduling - if (result.isFinished()) { - 
stage.schedulingComplete(); - } - else if (!result.getBlocked().isDone()) { - blockedStages.add(result.getBlocked()); - } - stageLinkages.get(stage.getStageId()) - .processScheduleResults(stage.getState(), result.getNewTasks()); - schedulerStats.getSplitsScheduledPerIteration().add(result.getSplitsScheduled()); - if (result.getBlockedReason().isPresent()) { - switch (result.getBlockedReason().get()) { - case WRITER_SCALING: - // no-op - break; - case WAITING_FOR_SOURCE: - schedulerStats.getWaitingForSource().update(1); - break; - case SPLIT_QUEUES_FULL: - schedulerStats.getSplitQueuesFull().update(1); - break; - case MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE: - case NO_ACTIVE_DRIVER_GROUP: - break; - default: - throw new UnsupportedOperationException("Unknown blocked reason: " + result.getBlockedReason().get()); - } - } + private static StageManager create( + QueryStateMachine queryStateMachine, + Session session, + Metadata metadata, + RemoteTaskFactory taskFactory, + NodeTaskMap nodeTaskMap, + ExecutorService executor, + SplitSchedulerStats schedulerStats, + SubPlan planTree, + boolean summarizeTaskInfo) + { + ImmutableMap.Builder stages = ImmutableMap.builder(); + ImmutableList.Builder coordinatorStagesInTopologicalOrder = ImmutableList.builder(); + ImmutableList.Builder distributedStagesInTopologicalOrder = ImmutableList.builder(); + StageId rootStageId = null; + ImmutableMap.Builder> children = ImmutableMap.builder(); + ImmutableMap.Builder parents = ImmutableMap.builder(); + for (SubPlan planNode : Traverser.forTree(SubPlan::getChildren).breadthFirst(planTree)) { + PlanFragment fragment = planNode.getFragment(); + SqlStage stage = createSqlStage( + getStageId(session.getQueryId(), fragment.getId()), + fragment, + extractTableInfo(session, metadata, fragment), + taskFactory, + session, + summarizeTaskInfo, + nodeTaskMap, + executor, + schedulerStats); + StageId stageId = stage.getStageId(); + stages.put(stageId, stage); + if 
(fragment.getPartitioning().isCoordinatorOnly()) { + coordinatorStagesInTopologicalOrder.add(stage); } - - // make sure to update stage linkage at least once per loop to catch async state changes (e.g., partial cancel) - for (SqlStageExecution stage : stages.values()) { - if (!completedStages.contains(stage.getStageId()) && stage.getState().isDone()) { - stageLinkages.get(stage.getStageId()) - .processScheduleResults(stage.getState(), ImmutableSet.of()); - completedStages.add(stage.getStageId()); - } + else { + distributedStagesInTopologicalOrder.add(stage); } - - // wait for a state change and then schedule again - if (!blockedStages.isEmpty()) { - try (TimeStat.BlockTimer timer = schedulerStats.getSleepTime().time()) { - tryGetFutureValue(whenAnyComplete(blockedStages), 1, SECONDS); - } - for (ListenableFuture blockedStage : blockedStages) { - blockedStage.cancel(true); - } + if (rootStageId == null) { + rootStageId = stageId; } + Set childStageIds = planNode.getChildren().stream() + .map(childStage -> getStageId(session.getQueryId(), childStage.getFragment().getId())) + .collect(toImmutableSet()); + children.put(stageId, childStageIds); + childStageIds.forEach(child -> parents.put(child, stageId)); } + StageManager stageManager = new StageManager( + queryStateMachine, + stages.build(), + coordinatorStagesInTopologicalOrder.build(), + distributedStagesInTopologicalOrder.build(), + rootStageId, + children.build(), + parents.build()); + stageManager.initialize(); + return stageManager; + } - for (SqlStageExecution stage : stages.values()) { - StageState state = stage.getState(); - if (state != SCHEDULED && state != RUNNING && state != FLUSHING && !state.isDone()) { - throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Scheduling is complete, but stage %s is in state %s", stage.getStageId(), state)); - } - } + private static Map extractTableInfo(Session session, Metadata metadata, PlanFragment fragment) + { + return searchFrom(fragment.getRoot()) + 
.where(TableScanNode.class::isInstance) + .findAll() + .stream() + .map(TableScanNode.class::cast) + .collect(toImmutableMap(PlanNode::getId, node -> getTableInfo(session, metadata, node))); } - catch (Throwable t) { - queryStateMachine.transitionToFailed(t); - throw t; + + private static TableInfo getTableInfo(Session session, Metadata metadata, TableScanNode node) + { + TableSchema tableSchema = metadata.getTableSchema(session, node.getTable()); + TableProperties tableProperties = metadata.getTableProperties(session, node.getTable()); + return new TableInfo(tableSchema.getQualifiedName(), tableProperties.getPredicate()); } - finally { - RuntimeException closeError = new RuntimeException(); - for (StageScheduler scheduler : stageSchedulers.values()) { - try { - scheduler.close(); - } - catch (Throwable t) { - queryStateMachine.transitionToFailed(t); - // Self-suppression not permitted - if (closeError != t) { - closeError.addSuppressed(t); - } - } - } - if (closeError.getSuppressed().length > 0) { - throw closeError; + + private static StageId getStageId(QueryId queryId, PlanFragmentId fragmentId) + { + // TODO: refactor fragment id to be based on an integer + return new StageId(queryId, parseInt(fragmentId.toString())); + } + + private StageManager( + QueryStateMachine queryStateMachine, + Map stages, + List coordinatorStagesInTopologicalOrder, + List distributedStagesInTopologicalOrder, + StageId rootStageId, + Map> children, + Map parents) + { + this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); + this.stages = ImmutableMap.copyOf(requireNonNull(stages, "stages is null")); + this.coordinatorStagesInTopologicalOrder = ImmutableList.copyOf(requireNonNull(coordinatorStagesInTopologicalOrder, "coordinatorStagesInTopologicalOrder is null")); + this.distributedStagesInTopologicalOrder = ImmutableList.copyOf(requireNonNull(distributedStagesInTopologicalOrder, "distributedStagesInTopologicalOrder is null")); + this.rootStageId = 
requireNonNull(rootStageId, "rootStageId is null"); + this.children = ImmutableMap.copyOf(requireNonNull(children, "children is null")); + this.parents = ImmutableMap.copyOf(requireNonNull(parents, "parents is null")); + } + + // this is a separate method to ensure that the `this` reference is not leaked during construction + private void initialize() + { + for (SqlStage stage : stages.values()) { + stage.addFinalStageInfoListener(status -> queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo()))); } } - } - public void cancelStage(StageId stageId) - { - try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { - SqlStageExecution sqlStageExecution = stages.get(stageId); - SqlStageExecution stage = requireNonNull(sqlStageExecution, () -> format("Stage '%s' does not exist", stageId)); - stage.cancel(); + public void finish() + { + stages.values().forEach(SqlStage::finish); } - } - public void abort() - { - try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { - stages.values().forEach(SqlStageExecution::abort); + public void abort() + { + stages.values().forEach(SqlStage::abort); } - } - private static ListenableFuture whenAllStages(Collection stages, Predicate predicate) - { - checkArgument(!stages.isEmpty(), "stages is empty"); - Set stageIds = stages.stream() - .map(SqlStageExecution::getStageId) - .collect(toCollection(Sets::newConcurrentHashSet)); - SettableFuture future = SettableFuture.create(); - - for (SqlStageExecution stage : stages) { - stage.addStateChangeListener(state -> { - if (predicate.test(state) && stageIds.remove(stage.getStageId()) && stageIds.isEmpty()) { - future.set(null); - } - }); + public List getCoordinatorStagesInTopologicalOrder() + { + return coordinatorStagesInTopologicalOrder; } - return future; - } + public List getDistributedStagesInTopologicalOrder() + { + return distributedStagesInTopologicalOrder; + } - private interface 
ExchangeLocationsConsumer - { - void addExchangeLocations(PlanFragmentId fragmentId, Set tasks, boolean noMoreExchangeLocations); - } + public SqlStage getOutputStage() + { + return stages.get(rootStageId); + } - private static class StageLinkage - { - private final PlanFragmentId currentStageFragmentId; - private final ExchangeLocationsConsumer parent; - private final Set childOutputBufferManagers; - private final Set childStageIds; - - public StageLinkage(PlanFragmentId fragmentId, ExchangeLocationsConsumer parent, Set children) - { - this.currentStageFragmentId = fragmentId; - this.parent = parent; - this.childOutputBufferManagers = children.stream() - .map(childStage -> { - PartitioningHandle partitioningHandle = childStage.getFragment().getPartitioningScheme().getPartitioning().getHandle(); - if (partitioningHandle.equals(FIXED_BROADCAST_DISTRIBUTION)) { - return new BroadcastOutputBufferManager(childStage::setOutputBuffers); - } - else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { - return new ScaledOutputBufferManager(childStage::setOutputBuffers); - } - else { - int partitionCount = Ints.max(childStage.getFragment().getPartitioningScheme().getBucketToPartition().get()) + 1; - return new PartitionedOutputBufferManager(partitioningHandle, partitionCount, childStage::setOutputBuffers); - } - }) - .collect(toImmutableSet()); + public SqlStage get(PlanFragmentId fragmentId) + { + return get(getStageId(queryStateMachine.getQueryId(), fragmentId)); + } + + public SqlStage get(StageId stageId) + { + return requireNonNull(stages.get(stageId), () -> "stage not found: " + stageId); + } + + public Set getChildren(PlanFragmentId fragmentId) + { + return getChildren(getStageId(queryStateMachine.getQueryId(), fragmentId)); + } - this.childStageIds = children.stream() - .map(SqlStageExecution::getStageId) + public Set getChildren(StageId stageId) + { + return children.get(stageId).stream() + .map(this::get) .collect(toImmutableSet()); } - public Set 
getChildStageIds() + public Optional getParent(PlanFragmentId fragmentId) { - return childStageIds; + return getParent(getStageId(queryStateMachine.getQueryId(), fragmentId)); } - public void processScheduleResults(StageState newState, Set newTasks) + public Optional getParent(StageId stageId) { - boolean noMoreTasks = !newState.canScheduleMoreTasks(); - // Add an exchange location to the parent stage for each new task - parent.addExchangeLocations(currentStageFragmentId, newTasks, noMoreTasks); + return Optional.ofNullable(parents.get(stageId)).map(stages::get); + } - if (!childOutputBufferManagers.isEmpty()) { - // Add an output buffer to the child stages for each new task - List newOutputBuffers = newTasks.stream() - .map(task -> new OutputBufferId(task.getTaskId().getId())) - .collect(toImmutableList()); - for (OutputBufferManager child : childOutputBufferManagers) { - child.addOutputBuffers(newOutputBuffers, noMoreTasks); - } + public BasicStageStats getBasicStageStats() + { + List stageStats = stages.values().stream() + .map(SqlStage::getBasicStageStats) + .collect(toImmutableList()); + + return aggregateBasicStageStats(stageStats); + } + + public StageInfo getStageInfo() + { + Map stageInfos = stages.values().stream() + .map(SqlStage::getStageInfo) + .collect(toImmutableMap(StageInfo::getStageId, identity())); + + return buildStageInfo(rootStageId, stageInfos); + } + + private StageInfo buildStageInfo(StageId stageId, Map stageInfos) + { + StageInfo parent = stageInfos.get(stageId); + checkArgument(parent != null, "No stageInfo for %s", parent); + List childStages = children.get(stageId).stream() + .map(childStageId -> buildStageInfo(childStageId, stageInfos)) + .collect(toImmutableList()); + if (childStages.isEmpty()) { + return parent; } + return new StageInfo( + parent.getStageId(), + parent.getState(), + parent.getPlan(), + parent.isCoordinatorOnly(), + parent.getTypes(), + parent.getStageStats(), + parent.getTasks(), + childStages, + parent.getTables(), 
+ parent.getFailureCause()); + } + + public long getUserMemoryReservation() + { + return stages.values().stream() + .mapToLong(SqlStage::getUserMemoryReservation) + .sum(); + } + + public long getTotalMemoryReservation() + { + return stages.values().stream() + .mapToLong(SqlStage::getTotalMemoryReservation) + .sum(); + } + + public Duration getTotalCpuTime() + { + long millis = stages.values().stream() + .mapToLong(stage -> stage.getTotalCpuTime().toMillis()) + .sum(); + return new Duration(millis, MILLISECONDS); + } + } + + private static class QueryOutputTaskLifecycleListener + implements TaskLifecycleListener + { + private final QueryStateMachine queryStateMachine; + + private QueryOutputTaskLifecycleListener(QueryStateMachine queryStateMachine) + { + this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); + } + + @Override + public void taskCreated(PlanFragmentId fragmentId, RemoteTask task) + { + Map bufferLocations = ImmutableMap.of( + task.getTaskId(), + uriBuilderFrom(task.getTaskStatus().getSelf()) + .appendPath("results") + .appendPath("0").build()); + queryStateMachine.updateOutputLocations(bufferLocations, false); + } + + @Override + public void noMoreTasks(PlanFragmentId fragmentId) + { + queryStateMachine.updateOutputLocations(ImmutableMap.of(), true); + } + } + + private static class CoordinatorStagesScheduler + { + private static final int[] SINGLE_PARTITION = new int[] {0}; + + private final QueryStateMachine queryStateMachine; + private final NodeScheduler nodeScheduler; + private final Map outputBuffersForStagesConsumedByCoordinator; + private final Map> bucketToPartitionForStagesConsumedByCoordinator; + private final TaskLifecycleListener taskLifecycleListener; + private final StageManager stageManager; + private final List stageExecutions; + private final AtomicReference distributedStagesScheduler; + private final TaskManager coordinatorTaskManager; + + private final AtomicBoolean scheduled = new 
AtomicBoolean(); + + public static CoordinatorStagesScheduler create( + QueryStateMachine queryStateMachine, + NodeScheduler nodeScheduler, + StageManager stageManager, + FailureDetector failureDetector, + Executor executor, + AtomicReference distributedStagesScheduler, + TaskManager coordinatorTaskManager) + { + Map outputBuffersForStagesConsumedByCoordinator = createOutputBuffersForStagesConsumedByCoordinator(stageManager); + Map> bucketToPartitionForStagesConsumedByCoordinator = createBucketToPartitionForStagesConsumedByCoordinator(stageManager); + + TaskLifecycleListener taskLifecycleListener = new QueryOutputTaskLifecycleListener(queryStateMachine); + // create executions + ImmutableList.Builder stageExecutions = ImmutableList.builder(); + for (SqlStage stage : stageManager.getCoordinatorStagesInTopologicalOrder()) { + PipelinedStageExecution stageExecution = createPipelinedStageExecution( + stage, + outputBuffersForStagesConsumedByCoordinator, + taskLifecycleListener, + failureDetector, + executor, + bucketToPartitionForStagesConsumedByCoordinator.get(stage.getFragment().getId()), + 0); + stageExecutions.add(stageExecution); + taskLifecycleListener = stageExecution.getTaskLifecycleListener(); + } + + CoordinatorStagesScheduler coordinatorStagesScheduler = new CoordinatorStagesScheduler( + queryStateMachine, + nodeScheduler, + outputBuffersForStagesConsumedByCoordinator, + bucketToPartitionForStagesConsumedByCoordinator, + taskLifecycleListener, + stageManager, + stageExecutions.build(), + distributedStagesScheduler, + coordinatorTaskManager); + coordinatorStagesScheduler.initialize(); + + return coordinatorStagesScheduler; + } + + private static Map createOutputBuffersForStagesConsumedByCoordinator(StageManager stageManager) + { + ImmutableMap.Builder result = ImmutableMap.builder(); + + // create output buffer for output stage + SqlStage outputStage = stageManager.getOutputStage(); + result.put(outputStage.getFragment().getId(), 
createSingleStreamOutputBuffer(outputStage)); + + // create output buffers for stages consumed by coordinator + for (SqlStage coordinatorStage : stageManager.getCoordinatorStagesInTopologicalOrder()) { + for (SqlStage childStage : stageManager.getChildren(coordinatorStage.getStageId())) { + result.put(childStage.getFragment().getId(), createSingleStreamOutputBuffer(childStage)); + } + } + + return result.build(); + } + + private static OutputBufferManager createSingleStreamOutputBuffer(SqlStage stage) + { + PartitioningHandle partitioningHandle = stage.getFragment().getPartitioningScheme().getPartitioning().getHandle(); + checkArgument(partitioningHandle.isSingleNode(), "partitioning is expected to be single node: " + partitioningHandle); + return new PartitionedOutputBufferManager(partitioningHandle, 1); + } + + private static Map> createBucketToPartitionForStagesConsumedByCoordinator(StageManager stageManager) + { + ImmutableMap.Builder> result = ImmutableMap.builder(); + + SqlStage outputStage = stageManager.getOutputStage(); + result.put(outputStage.getFragment().getId(), Optional.of(SINGLE_PARTITION)); + + for (SqlStage coordinatorStage : stageManager.getCoordinatorStagesInTopologicalOrder()) { + for (SqlStage childStage : stageManager.getChildren(coordinatorStage.getStageId())) { + result.put(childStage.getFragment().getId(), Optional.of(SINGLE_PARTITION)); + } + } + + return result.build(); + } + + private CoordinatorStagesScheduler( + QueryStateMachine queryStateMachine, + NodeScheduler nodeScheduler, + Map outputBuffersForStagesConsumedByCoordinator, + Map> bucketToPartitionForStagesConsumedByCoordinator, + TaskLifecycleListener taskLifecycleListener, + StageManager stageManager, + List stageExecutions, + AtomicReference distributedStagesScheduler, + TaskManager coordinatorTaskManager) + { + this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); + this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); 
+ this.outputBuffersForStagesConsumedByCoordinator = ImmutableMap.copyOf(requireNonNull(outputBuffersForStagesConsumedByCoordinator, "outputBuffersForStagesConsumedByCoordinator is null")); + this.bucketToPartitionForStagesConsumedByCoordinator = ImmutableMap.copyOf(requireNonNull(bucketToPartitionForStagesConsumedByCoordinator, "bucketToPartitionForStagesConsumedByCoordinator is null")); + this.taskLifecycleListener = requireNonNull(taskLifecycleListener, "taskLifecycleListener is null"); + this.stageManager = requireNonNull(stageManager, "stageManager is null"); + this.stageExecutions = ImmutableList.copyOf(requireNonNull(stageExecutions, "stageExecutions is null")); + this.distributedStagesScheduler = requireNonNull(distributedStagesScheduler, "distributedStagesScheduler is null"); + this.coordinatorTaskManager = requireNonNull(coordinatorTaskManager, "coordinatorTaskManager is null"); + } + + private void initialize() + { + for (PipelinedStageExecution stageExecution : stageExecutions) { + stageExecution.addStateChangeListener(state -> { + if (queryStateMachine.isDone()) { + return; + } + // if any coordinator stage failed transition directly to failure + if (state == FAILED) { + RuntimeException failureCause = stageExecution.getFailureCause() + .map(ExecutionFailureInfo::toException) + .orElseGet(() -> new VerifyException(format("stage execution for stage %s is failed by failure cause is not present", stageExecution.getStageId()))); + stageManager.get(stageExecution.getStageId()).fail(failureCause); + queryStateMachine.transitionToFailed(failureCause); + } + else if (state == ABORTED) { + // this should never happen, since abort can only be triggered in query clean up after the query is finished + stageManager.get(stageExecution.getStageId()).abort(); + queryStateMachine.transitionToFailed(new TrinoException(GENERIC_INTERNAL_ERROR, "Query stage was aborted")); + } + else if (state.isDone()) { + stageManager.get(stageExecution.getStageId()).finish(); + } + }); 
+ } + + for (int currentIndex = 0, nextIndex = 1; nextIndex < stageExecutions.size(); currentIndex++, nextIndex++) { + PipelinedStageExecution stageExecution = stageExecutions.get(currentIndex); + PipelinedStageExecution childStageExecution = stageExecutions.get(nextIndex); + Set childStages = stageManager.getChildren(stageExecution.getStageId()); + verify(childStages.size() == 1, "exactly one child stage is expected"); + SqlStage childStage = getOnlyElement(childStages); + verify(childStage.getStageId().equals(childStageExecution.getStageId()), "stage execution order doesn't match the stage order"); + stageExecution.addStateChangeListener(newState -> { + if (newState == FLUSHING || newState.isDone()) { + childStageExecution.cancel(); + } + }); + } + + Optional root = Optional.ofNullable(getFirst(stageExecutions, null)); + root.ifPresent(stageExecution -> stageExecution.addStateChangeListener(state -> { + if (state == FINISHED) { + queryStateMachine.transitionToFinishing(); + } + else if (state == CANCELED) { + // output stage was canceled + queryStateMachine.transitionToCanceled(); + } + })); + + Optional last = Optional.ofNullable(getLast(stageExecutions, null)); + last.ifPresent(stageExecution -> stageExecution.addStateChangeListener(newState -> { + if (newState == FLUSHING || newState.isDone()) { + DistributedStagesScheduler distributedStagesScheduler = this.distributedStagesScheduler.get(); + if (distributedStagesScheduler != null) { + distributedStagesScheduler.cancel(); + } + } + })); + } + + public void schedule() + { + if (!scheduled.compareAndSet(false, true)) { + return; + } + + /* + * Tasks have 2 communication links: + * + * Task <-> Coordinator (for status updates) + * Task <-> Downstream Task (for exchanging the task results) + * + * In a scenario when a link between a task and a downstream task is broken (while the link between a + * task and coordinator is not) without failure recovery enabled the downstream task would discover + * that the 
communication link is broken and fail a query. + * + * However with failure recovery enabled a downstream task is configured to ignore the failures to survive an + * upstream task failure. That may result in a "deadlock", when the coordinator thinks that a task is active, + * but since the communication link between the task and its downstream task is broken nobody is polling + * the results leaving it in a blocked state. Thus it is important to notify the scheduler about such + * communication failures so the scheduler can react and re-schedule a task. + * + * Currently only "coordinator" tasks have to survive an upstream task failure (for example a task that performs + * table commit). Restarting a table commit task introduces another set of challenges (such as making sure the commit + * operation is always idempotent). Given that only coordinator tasks have to survive a failure there's a shortcut in + * implementation of the error reporting. The assumption is that scheduling also happens on coordinator, thus no RPC is + * involved in notifying the coordinator. Whenever it is needed to separate scheduling and coordinator tasks on different + * nodes an RPC mechanism for this notification has to be implemented. + * + * Note: For queries that don't have any coordinator stages the situation is still similar. The exchange client that + * pulls the final query results has to propagate the same notification if the communication link between the exchange client + * and one of the output tasks is broken. 
+ */ + TaskFailureReporter failureReporter = new TaskFailureReporter(distributedStagesScheduler); + queryStateMachine.addOutputTaskFailureListener(failureReporter); + + InternalNode coordinator = nodeScheduler.createNodeSelector(queryStateMachine.getSession(), Optional.empty()).selectCurrentNode(); + for (PipelinedStageExecution stageExecution : stageExecutions) { + Optional remoteTask = stageExecution.scheduleTask( + coordinator, + 0, + ImmutableMultimap.of(), + ImmutableMultimap.of()); + stageExecution.schedulingComplete(); + remoteTask.ifPresent(task -> coordinatorTaskManager.addSourceTaskFailureListener(task.getTaskId(), failureReporter)); + } + } + + public Map getOutputBuffersForStagesConsumedByCoordinator() + { + return outputBuffersForStagesConsumedByCoordinator; + } + + public Map> getBucketToPartitionForStagesConsumedByCoordinator() + { + return bucketToPartitionForStagesConsumedByCoordinator; + } + + public TaskLifecycleListener getTaskLifecycleListener() + { + return taskLifecycleListener; + } + + public void cancelStage(StageId stageId) + { + for (PipelinedStageExecution stageExecution : stageExecutions) { + if (stageExecution.getStageId().equals(stageId)) { + stageExecution.cancel(); + } + } + } + + public void cancel() + { + stageExecutions.forEach(PipelinedStageExecution::cancel); + } + + public void abort() + { + stageExecutions.forEach(PipelinedStageExecution::abort); + } + } + + private static class TaskFailureReporter + implements TaskFailureListener + { + private final AtomicReference distributedStagesScheduler; + + private TaskFailureReporter(AtomicReference distributedStagesScheduler) + { + this.distributedStagesScheduler = distributedStagesScheduler; + } + + @Override + public void onTaskFailed(TaskId taskId, Throwable failure) + { + if (failure instanceof TrinoException && ((TrinoException) failure).getErrorCode() == REMOTE_TASK_FAILED.toErrorCode()) { + // This error indicates that a downstream task was trying to fetch results from an 
upstream task that is marked as failed + // Instead of failing a downstream task let the coordinator handle and report the failure of an upstream task to ensure correct error reporting + log.info("Task failure discovered while fetching task results: %s", taskId); + return; + } + log.warn(failure, "Reported task failure: %s", taskId); + DistributedStagesScheduler scheduler = this.distributedStagesScheduler.get(); + if (scheduler != null) { + scheduler.reportTaskFailure(taskId, failure); + } + } + } + + private interface DistributedStagesScheduler + { + void schedule(); + + void cancelStage(StageId stageId); + + void cancel(); + + void abort(); + + void reportTaskFailure(TaskId taskId, Throwable failureCause); + + void addStateChangeListener(StateChangeListener stateChangeListener); + + Optional getFailureCause(); + } + + private static class PipelinedDistributedStagesScheduler + implements DistributedStagesScheduler + { + private final DistributedStagesSchedulerStateMachine stateMachine; + private final QueryStateMachine queryStateMachine; + private final SplitSchedulerStats schedulerStats; + private final StageManager stageManager; + private final ExecutionSchedule executionSchedule; + private final Map stageSchedulers; + private final Map stageExecutions; + private final DynamicFilterService dynamicFilterService; + + private final AtomicBoolean started = new AtomicBoolean(); + + public static PipelinedDistributedStagesScheduler create( + QueryStateMachine queryStateMachine, + SplitSchedulerStats schedulerStats, + NodeScheduler nodeScheduler, + NodePartitioningManager nodePartitioningManager, + StageManager stageManager, + CoordinatorStagesScheduler coordinatorStagesScheduler, + ExecutionPolicy executionPolicy, + FailureDetector failureDetector, + ScheduledExecutorService executor, + SplitSourceFactory splitSourceFactory, + int splitBatchSize, + DynamicFilterService dynamicFilterService, + TableExecuteContextManager tableExecuteContextManager, + RetryPolicy 
retryPolicy, + int attempt) + { + DistributedStagesSchedulerStateMachine stateMachine = new DistributedStagesSchedulerStateMachine(queryStateMachine.getQueryId(), executor); + + Map partitioningCacheMap = new HashMap<>(); + Function partitioningCache = partitioningHandle -> + partitioningCacheMap.computeIfAbsent(partitioningHandle, handle -> nodePartitioningManager.getNodePartitioningMap(queryStateMachine.getSession(), handle)); + + Map> bucketToPartitionMap = createBucketToPartitionMap( + coordinatorStagesScheduler.getBucketToPartitionForStagesConsumedByCoordinator(), + stageManager, + partitioningCache); + Map outputBufferManagers = createOutputBufferManagers( + coordinatorStagesScheduler.getOutputBuffersForStagesConsumedByCoordinator(), + stageManager, + bucketToPartitionMap); + + TaskLifecycleListener coordinatorTaskLifecycleListener = coordinatorStagesScheduler.getTaskLifecycleListener(); + if (retryPolicy != RetryPolicy.NONE) { + // when retries are enabled only close exchange clients on coordinator when the query is finished + TaskLifecycleListenerBridge taskLifecycleListenerBridge = new TaskLifecycleListenerBridge(coordinatorTaskLifecycleListener); + coordinatorTaskLifecycleListener = taskLifecycleListenerBridge; + stateMachine.addStateChangeListener(state -> { + if (state == DistributedStagesSchedulerState.FINISHED) { + taskLifecycleListenerBridge.notifyNoMoreSourceTasks(); + } + }); + } + + Map stageExecutions = new HashMap<>(); + for (SqlStage stage : stageManager.getDistributedStagesInTopologicalOrder()) { + Optional parentStage = stageManager.getParent(stage.getStageId()); + TaskLifecycleListener taskLifecycleListener; + if (parentStage.isEmpty() || parentStage.get().getFragment().getPartitioning().isCoordinatorOnly()) { + // output will be consumed by coordinator + taskLifecycleListener = coordinatorTaskLifecycleListener; + } + else { + StageId parentStageId = parentStage.get().getStageId(); + PipelinedStageExecution parentStageExecution = 
requireNonNull(stageExecutions.get(parentStageId), () -> "execution is null for stage: " + parentStageId); + taskLifecycleListener = parentStageExecution.getTaskLifecycleListener(); + } + + PlanFragment fragment = stage.getFragment(); + PipelinedStageExecution stageExecution = createPipelinedStageExecution( + stageManager.get(fragment.getId()), + outputBufferManagers, + taskLifecycleListener, + failureDetector, + executor, + bucketToPartitionMap.get(fragment.getId()), + attempt); + stageExecutions.put(stage.getStageId(), stageExecution); + } + + ImmutableMap.Builder stageSchedulers = ImmutableMap.builder(); + for (PipelinedStageExecution stageExecution : stageExecutions.values()) { + List children = stageManager.getChildren(stageExecution.getStageId()).stream() + .map(stage -> requireNonNull(stageExecutions.get(stage.getStageId()), () -> "stage execution not found for stage: " + stage)) + .collect(toImmutableList()); + StageScheduler scheduler = createStageScheduler( + queryStateMachine, + stageExecution, + splitSourceFactory, + children, + partitioningCache, + nodeScheduler, + nodePartitioningManager, + splitBatchSize, + dynamicFilterService, + executor, + tableExecuteContextManager); + stageSchedulers.put(stageExecution.getStageId(), scheduler); + } + + PipelinedDistributedStagesScheduler distributedStagesScheduler = new PipelinedDistributedStagesScheduler( + stateMachine, + queryStateMachine, + schedulerStats, + stageManager, + executionPolicy.createExecutionSchedule(stageExecutions.values()), + stageSchedulers.build(), + ImmutableMap.copyOf(stageExecutions), + dynamicFilterService); + distributedStagesScheduler.initialize(); + return distributedStagesScheduler; + } + + private static Map> createBucketToPartitionMap( + Map> bucketToPartitionForStagesConsumedByCoordinator, + StageManager stageManager, + Function partitioningCache) + { + ImmutableMap.Builder> result = ImmutableMap.builder(); + result.putAll(bucketToPartitionForStagesConsumedByCoordinator); + for 
(SqlStage stage : stageManager.getDistributedStagesInTopologicalOrder()) { + PlanFragment fragment = stage.getFragment(); + Optional bucketToPartition = getBucketToPartition(fragment.getPartitioning(), partitioningCache, fragment.getRoot(), fragment.getRemoteSourceNodes()); + for (SqlStage childStage : stageManager.getChildren(stage.getStageId())) { + result.put(childStage.getFragment().getId(), bucketToPartition); + } + } + return result.build(); + } + + private static Optional getBucketToPartition( + PartitioningHandle partitioningHandle, + Function partitioningCache, + PlanNode fragmentRoot, + List remoteSourceNodes) + { + if (partitioningHandle.equals(SOURCE_DISTRIBUTION) || partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { + return Optional.of(new int[1]); + } + else if (searchFrom(fragmentRoot).where(node -> node instanceof TableScanNode).findFirst().isPresent()) { + if (remoteSourceNodes.stream().allMatch(node -> node.getExchangeType() == REPLICATE)) { + return Optional.empty(); + } + else { + // remote source requires nodePartitionMap + NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle); + return Optional.of(nodePartitionMap.getBucketToPartition()); + } + } + else { + NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle); + List partitionToNode = nodePartitionMap.getPartitionToNode(); + // todo this should asynchronously wait a standard timeout period before failing + checkCondition(!partitionToNode.isEmpty(), NO_NODES_AVAILABLE, "No worker nodes available"); + return Optional.of(nodePartitionMap.getBucketToPartition()); + } + } + + private static Map createOutputBufferManagers( + Map outputBuffersForStagesConsumedByCoordinator, + StageManager stageManager, + Map> bucketToPartitionMap) + { + ImmutableMap.Builder result = ImmutableMap.builder(); + result.putAll(outputBuffersForStagesConsumedByCoordinator); + for (SqlStage parentStage : stageManager.getDistributedStagesInTopologicalOrder()) { + for 
(SqlStage childStage : stageManager.getChildren(parentStage.getStageId())) { + PlanFragmentId fragmentId = childStage.getFragment().getId(); + PartitioningHandle partitioningHandle = childStage.getFragment().getPartitioningScheme().getPartitioning().getHandle(); + + OutputBufferManager outputBufferManager; + if (partitioningHandle.equals(FIXED_BROADCAST_DISTRIBUTION)) { + outputBufferManager = new BroadcastOutputBufferManager(); + } + else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { + outputBufferManager = new ScaledOutputBufferManager(); + } + else { + Optional bucketToPartition = bucketToPartitionMap.get(fragmentId); + checkArgument(bucketToPartition.isPresent(), "bucketToPartition is expected to be present for fragment: %s", fragmentId); + int partitionCount = Ints.max(bucketToPartition.get()) + 1; + outputBufferManager = new PartitionedOutputBufferManager(partitioningHandle, partitionCount); + } + result.put(fragmentId, outputBufferManager); + } + } + return result.build(); + } + + private static StageScheduler createStageScheduler( + QueryStateMachine queryStateMachine, + PipelinedStageExecution stageExecution, + SplitSourceFactory splitSourceFactory, + List childStageExecutions, + Function partitioningCache, + NodeScheduler nodeScheduler, + NodePartitioningManager nodePartitioningManager, + int splitBatchSize, + DynamicFilterService dynamicFilterService, + ScheduledExecutorService executor, + TableExecuteContextManager tableExecuteContextManager) + { + Session session = queryStateMachine.getSession(); + PlanFragment fragment = stageExecution.getFragment(); + PartitioningHandle partitioningHandle = fragment.getPartitioning(); + Map splitSources = splitSourceFactory.createSplitSources(session, fragment); + if (!splitSources.isEmpty()) { + queryStateMachine.addStateChangeListener(new StateChangeListener<>() + { + private final AtomicReference> splitSourcesReference = new AtomicReference<>(splitSources.values()); + + @Override + public void 
stateChanged(QueryState newState) + { + if (newState.isDone()) { + // ensure split sources are closed and release memory + Collection sources = splitSourcesReference.getAndSet(null); + if (sources != null) { + closeSplitSources(sources); + } + } + } + }); + } + if (partitioningHandle.equals(SOURCE_DISTRIBUTION)) { + // nodes are selected dynamically based on the constraints of the splits and the system load + Entry entry = getOnlyElement(splitSources.entrySet()); + PlanNodeId planNodeId = entry.getKey(); + SplitSource splitSource = entry.getValue(); + Optional catalogName = Optional.of(splitSource.getCatalogName()) + .filter(catalog -> !isInternalSystemConnector(catalog)); + NodeSelector nodeSelector = nodeScheduler.createNodeSelector(session, catalogName); + SplitPlacementPolicy placementPolicy = new DynamicSplitPlacementPolicy(nodeSelector, stageExecution::getAllTasks); + + checkArgument(!fragment.getStageExecutionDescriptor().isStageGroupedExecution()); + + return newSourcePartitionedSchedulerAsStageScheduler( + stageExecution, + planNodeId, + splitSource, + placementPolicy, + splitBatchSize, + dynamicFilterService, + tableExecuteContextManager, + () -> childStageExecutions.stream().anyMatch(PipelinedStageExecution::isAnyTaskBlocked)); + } + else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { + Supplier> sourceTasksProvider = () -> childStageExecutions.stream() + .map(PipelinedStageExecution::getTaskStatuses) + .flatMap(List::stream) + .collect(toImmutableList()); + Supplier> writerTasksProvider = stageExecution::getTaskStatuses; + + ScaledWriterScheduler scheduler = new ScaledWriterScheduler( + stageExecution, + sourceTasksProvider, + writerTasksProvider, + nodeScheduler.createNodeSelector(session, Optional.empty()), + executor, + getWriterMinSize(session)); + + whenAllStages(childStageExecutions, PipelinedStageExecution.State::isDone) + .addListener(scheduler::finish, directExecutor()); + + return scheduler; + } + else { + if 
(!splitSources.isEmpty()) { + // contains local source + List schedulingOrder = fragment.getPartitionedSources(); + Optional catalogName = partitioningHandle.getConnectorId(); + checkArgument(catalogName.isPresent(), "No connector ID for partitioning handle: %s", partitioningHandle); + List connectorPartitionHandles; + boolean groupedExecutionForStage = fragment.getStageExecutionDescriptor().isStageGroupedExecution(); + if (groupedExecutionForStage) { + connectorPartitionHandles = nodePartitioningManager.listPartitionHandles(session, partitioningHandle); + checkState(!ImmutableList.of(NOT_PARTITIONED).equals(connectorPartitionHandles)); + } + else { + connectorPartitionHandles = ImmutableList.of(NOT_PARTITIONED); + } + + BucketNodeMap bucketNodeMap; + List stageNodeList; + if (fragment.getRemoteSourceNodes().stream().allMatch(node -> node.getExchangeType() == REPLICATE)) { + // no remote source + boolean dynamicLifespanSchedule = fragment.getStageExecutionDescriptor().isDynamicLifespanSchedule(); + bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle, dynamicLifespanSchedule); + + // verify execution is consistent with planner's decision on dynamic lifespan schedule + verify(bucketNodeMap.isDynamic() == dynamicLifespanSchedule); + + stageNodeList = new ArrayList<>(nodeScheduler.createNodeSelector(session, catalogName).allNodes()); + Collections.shuffle(stageNodeList); + } + else { + // cannot use dynamic lifespan schedule + verify(!fragment.getStageExecutionDescriptor().isDynamicLifespanSchedule()); + + // remote source requires nodePartitionMap + NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle); + if (groupedExecutionForStage) { + checkState(connectorPartitionHandles.size() == nodePartitionMap.getBucketToPartition().length); + } + stageNodeList = nodePartitionMap.getPartitionToNode(); + bucketNodeMap = nodePartitionMap.asBucketNodeMap(); + } + + return new FixedSourcePartitionedScheduler( + 
stageExecution, + splitSources, + fragment.getStageExecutionDescriptor(), + schedulingOrder, + stageNodeList, + bucketNodeMap, + splitBatchSize, + getConcurrentLifespansPerNode(session), + nodeScheduler.createNodeSelector(session, catalogName), + connectorPartitionHandles, + dynamicFilterService, + tableExecuteContextManager); + } + else { + // all sources are remote + NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle); + List partitionToNode = nodePartitionMap.getPartitionToNode(); + // todo this should asynchronously wait a standard timeout period before failing + checkCondition(!partitionToNode.isEmpty(), NO_NODES_AVAILABLE, "No worker nodes available"); + return new FixedCountScheduler(stageExecution, partitionToNode); + } + } + } + + private static void closeSplitSources(Collection splitSources) + { + for (SplitSource source : splitSources) { + try { + source.close(); + } + catch (Throwable t) { + log.warn(t, "Error closing split source"); + } + } + } + + private static ListenableFuture whenAllStages(Collection stages, Predicate predicate) + { + checkArgument(!stages.isEmpty(), "stages is empty"); + Set stageIds = stages.stream() + .map(PipelinedStageExecution::getStageId) + .collect(toCollection(Sets::newConcurrentHashSet)); + SettableFuture future = SettableFuture.create(); + + for (PipelinedStageExecution stageExecution : stages) { + stageExecution.addStateChangeListener(state -> { + if (predicate.test(state) && stageIds.remove(stageExecution.getStageId()) && stageIds.isEmpty()) { + future.set(null); + } + }); + } + + return future; + } + + private PipelinedDistributedStagesScheduler( + DistributedStagesSchedulerStateMachine stateMachine, + QueryStateMachine queryStateMachine, + SplitSchedulerStats schedulerStats, + StageManager stageManager, + ExecutionSchedule executionSchedule, + Map stageSchedulers, + Map stageExecutions, + DynamicFilterService dynamicFilterService) + { + this.stateMachine = requireNonNull(stateMachine, 
"stateMachine is null"); + this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); + this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); + this.stageManager = requireNonNull(stageManager, "stageManager is null"); + this.executionSchedule = requireNonNull(executionSchedule, "executionSchedule is null"); + this.stageSchedulers = ImmutableMap.copyOf(requireNonNull(stageSchedulers, "stageSchedulers is null")); + this.stageExecutions = ImmutableMap.copyOf(requireNonNull(stageExecutions, "stageExecutions is null")); + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + } + + private void initialize() + { + for (PipelinedStageExecution stageExecution : stageExecutions.values()) { + List childStageExecutions = stageManager.getChildren(stageExecution.getStageId()).stream() + .map(stage -> requireNonNull(stageExecutions.get(stage.getStageId()), () -> "stage execution not found for stage: " + stage)) + .collect(toImmutableList()); + if (!childStageExecutions.isEmpty()) { + stageExecution.addStateChangeListener(newState -> { + if (newState == FLUSHING || newState.isDone()) { + childStageExecutions.forEach(PipelinedStageExecution::cancel); + } + }); + } + } + + Set finishedStages = newConcurrentHashSet(); + for (PipelinedStageExecution stageExecution : stageExecutions.values()) { + stageExecution.addStateChangeListener(state -> { + if (stateMachine.getState().isDone()) { + return; + } + int numberOfTasks = stageExecution.getAllTasks().size(); + // TODO: support dynamic filter for failure retries + if (!state.canScheduleMoreTasks()) { + dynamicFilterService.stageCannotScheduleMoreTasks(stageExecution.getStageId(), numberOfTasks); + } + if (numberOfTasks != 0) { + stateMachine.transitionToRunning(); + } + if (state == FAILED) { + RuntimeException failureCause = stageExecution.getFailureCause() + .map(ExecutionFailureInfo::toException) + .orElseGet(() -> new 
VerifyException(format("stage execution for stage %s is failed by failure cause is not present", stageExecution.getStageId()))); + fail(failureCause, Optional.of(stageExecution.getStageId())); + } + else if (state.isDone()) { + finishedStages.add(stageExecution.getStageId()); + if (finishedStages.containsAll(stageExecutions.keySet())) { + stateMachine.transitionToFinished(); + } + } + }); + } + } + + @Override + public void schedule() + { + checkState(started.compareAndSet(false, true), "already started"); + + try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { + while (!executionSchedule.isFinished()) { + List> blockedStages = new ArrayList<>(); + for (PipelinedStageExecution stageExecution : executionSchedule.getStagesToSchedule()) { + stageExecution.beginScheduling(); + + // perform some scheduling work + ScheduleResult result = stageSchedulers.get(stageExecution.getStageId()) + .schedule(); + + // modify parent and children based on the results of the scheduling + if (result.isFinished()) { + stageExecution.schedulingComplete(); + } + else if (!result.getBlocked().isDone()) { + blockedStages.add(result.getBlocked()); + } + schedulerStats.getSplitsScheduledPerIteration().add(result.getSplitsScheduled()); + if (result.getBlockedReason().isPresent()) { + switch (result.getBlockedReason().get()) { + case WRITER_SCALING: + // no-op + break; + case WAITING_FOR_SOURCE: + schedulerStats.getWaitingForSource().update(1); + break; + case SPLIT_QUEUES_FULL: + schedulerStats.getSplitQueuesFull().update(1); + break; + case MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE: + case NO_ACTIVE_DRIVER_GROUP: + break; + default: + throw new UnsupportedOperationException("Unknown blocked reason: " + result.getBlockedReason().get()); + } + } + } + + // wait for a state change and then schedule again + if (!blockedStages.isEmpty()) { + try (TimeStat.BlockTimer timer = schedulerStats.getSleepTime().time()) { + 
tryGetFutureValue(whenAnyComplete(blockedStages), 1, SECONDS); + } + for (ListenableFuture blockedStage : blockedStages) { + blockedStage.cancel(true); + } + } + } + + for (PipelinedStageExecution stageExecution : stageExecutions.values()) { + PipelinedStageExecution.State state = stageExecution.getState(); + if (state != SCHEDULED && state != RUNNING && state != FLUSHING && !state.isDone()) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Scheduling is complete, but stage %s is in state %s", stageExecution.getStageId(), state)); + } + } + } + catch (Throwable t) { + fail(t, Optional.empty()); + } + finally { + RuntimeException closeError = new RuntimeException(); + for (StageScheduler scheduler : stageSchedulers.values()) { + try { + scheduler.close(); + } + catch (Throwable t) { + fail(t, Optional.empty()); + // Self-suppression not permitted + if (closeError != t) { + closeError.addSuppressed(t); + } + } + } + } + } + + @Override + public void cancelStage(StageId stageId) + { + PipelinedStageExecution stageExecution = stageExecutions.get(stageId); + if (stageExecution != null) { + stageExecution.cancel(); + } + } + + @Override + public void cancel() + { + stateMachine.transitionToCanceled(); + stageExecutions.values().forEach(PipelinedStageExecution::cancel); + } + + @Override + public void abort() + { + stateMachine.transitionToAborted(); + stageExecutions.values().forEach(PipelinedStageExecution::abort); + } + + public void fail(Throwable failureCause, Optional failedStageId) + { + stateMachine.transitionToFailed(failureCause, failedStageId); + stageExecutions.values().forEach(PipelinedStageExecution::abort); + } + + @Override + public void reportTaskFailure(TaskId taskId, Throwable failureCause) + { + PipelinedStageExecution stageExecution = stageExecutions.get(taskId.getStageId()); + if (stageExecution == null) { + return; + } + + List tasks = stageExecution.getAllTasks(); + if (tasks.stream().noneMatch(task -> task.getTaskId().equals(taskId))) { 
+ return; + } + + stageExecution.failTask(taskId, failureCause); + stateMachine.transitionToFailed(failureCause, Optional.of(taskId.getStageId())); + stageExecutions.values().forEach(PipelinedStageExecution::abort); + } + + @Override + public void addStateChangeListener(StateChangeListener stateChangeListener) + { + stateMachine.addStateChangeListener(stateChangeListener); + } + + @Override + public Optional getFailureCause() + { + return stateMachine.getFailureCause(); + } + } + + private enum DistributedStagesSchedulerState + { + PLANNED(false, false), + RUNNING(false, false), + FINISHED(true, false), + CANCELED(true, false), + ABORTED(true, true), + FAILED(true, true); + + public static final Set TERMINAL_STATES = Stream.of(DistributedStagesSchedulerState.values()).filter(DistributedStagesSchedulerState::isDone).collect(toImmutableSet()); + + private final boolean doneState; + private final boolean failureState; + + DistributedStagesSchedulerState(boolean doneState, boolean failureState) + { + checkArgument(!failureState || doneState, "%s is a non-done failure state", name()); + this.doneState = doneState; + this.failureState = failureState; + } + + /** + * Is this a terminal state. + */ + public boolean isDone() + { + return doneState; + } + + /** + * Is this a non-success terminal state. 
+ */ + public boolean isFailure() + { + return failureState; + } + + public boolean isRunningOrDone() + { + return this == RUNNING || isDone(); + } + } + + private static class DistributedStagesSchedulerStateMachine + { + private final QueryId queryId; + private final StateMachine state; + private final AtomicReference failureCause = new AtomicReference<>(); + + public DistributedStagesSchedulerStateMachine(QueryId queryId, Executor executor) + { + this.queryId = requireNonNull(queryId, "queryId is null"); + requireNonNull(executor, "executor is null"); + state = new StateMachine<>("Distributed stages scheduler", executor, DistributedStagesSchedulerState.PLANNED, DistributedStagesSchedulerState.TERMINAL_STATES); + } + + public DistributedStagesSchedulerState getState() + { + return state.get(); + } + + public boolean transitionToRunning() + { + return state.setIf(DistributedStagesSchedulerState.RUNNING, currentState -> !currentState.isDone()); + } + + public boolean transitionToFinished() + { + return state.setIf(DistributedStagesSchedulerState.FINISHED, currentState -> !currentState.isDone()); + } + + public boolean transitionToCanceled() + { + return state.setIf(DistributedStagesSchedulerState.CANCELED, currentState -> !currentState.isDone()); + } + + public boolean transitionToAborted() + { + return state.setIf(DistributedStagesSchedulerState.ABORTED, currentState -> !currentState.isDone()); + } + + public boolean transitionToFailed(Throwable throwable, Optional failedStageId) + { + requireNonNull(throwable, "throwable is null"); + + failureCause.compareAndSet(null, new StageFailureInfo(toFailure(throwable), failedStageId)); + boolean failed = state.setIf(DistributedStagesSchedulerState.FAILED, currentState -> !currentState.isDone()); + if (failed) { + log.error(throwable, "Failure in distributed stage for query %s", queryId); + } + else { + log.debug(throwable, "Failure in distributed stage for query %s after finished", queryId); + } + return failed; + } + + 
public Optional getFailureCause() + { + return Optional.ofNullable(failureCause.get()); + } + + /** + * Listener is always notified asynchronously using a dedicated notification thread pool so, care should + * be taken to avoid leaking {@code this} when adding a listener in a constructor. Additionally, it is + * possible notifications are observed out of order due to the asynchronous execution. + */ + public void addStateChangeListener(StateChangeListener stateChangeListener) + { + state.addStateChangeListener(stateChangeListener); + } + } + + private static class TaskLifecycleListenerBridge + implements TaskLifecycleListener + { + private final TaskLifecycleListener listener; + + @GuardedBy("this") + private final Set noMoreSourceTasks = new HashSet<>(); + @GuardedBy("this") + private boolean done; + + private TaskLifecycleListenerBridge(TaskLifecycleListener listener) + { + this.listener = requireNonNull(listener, "listener is null"); + } + + @Override + public synchronized void taskCreated(PlanFragmentId fragmentId, RemoteTask task) + { + checkState(!done, "unexpected state"); + listener.taskCreated(fragmentId, task); + } + + @Override + public synchronized void noMoreTasks(PlanFragmentId fragmentId) + { + checkState(!done, "unexpected state"); + noMoreSourceTasks.add(fragmentId); + } + + public synchronized void notifyNoMoreSourceTasks() + { + checkState(!done, "unexpected state"); + done = true; + noMoreSourceTasks.forEach(listener::noMoreTasks); + } + } + + private static class StageFailureInfo + { + private final ExecutionFailureInfo failureInfo; + private final Optional failedStageId; + + private StageFailureInfo(ExecutionFailureInfo failureInfo, Optional failedStageId) + { + this.failureInfo = requireNonNull(failureInfo, "failureInfo is null"); + this.failedStageId = requireNonNull(failedStageId, "failedStageId is null"); + } + + public ExecutionFailureInfo getFailureInfo() + { + return failureInfo; + } + + public Optional getFailedStageId() + { + return 
failedStageId; } } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskLifecycleListener.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskLifecycleListener.java new file mode 100644 index 000000000000..999eede414e2 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskLifecycleListener.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import io.trino.execution.RemoteTask; +import io.trino.sql.planner.plan.PlanFragmentId; + +public interface TaskLifecycleListener +{ + void taskCreated(PlanFragmentId fragmentId, RemoteTask task); + + void noMoreTasks(PlanFragmentId fragmentId); + + TaskLifecycleListener NO_OP = new TaskLifecycleListener() + { + @Override + public void taskCreated(PlanFragmentId fragmentId, RemoteTask task) + { + } + + @Override + public void noMoreTasks(PlanFragmentId fragmentId) + { + } + }; +} diff --git a/core/trino-main/src/main/java/io/trino/operator/AssignUniqueIdOperator.java b/core/trino-main/src/main/java/io/trino/operator/AssignUniqueIdOperator.java index 8c4484e82733..6bc2835ba442 100644 --- a/core/trino-main/src/main/java/io/trino/operator/AssignUniqueIdOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/AssignUniqueIdOperator.java @@ -125,7 +125,7 @@ private AssignUniqueId(TaskId taskId, AtomicLong rowIdPool) { this.rowIdPool = requireNonNull(rowIdPool, "rowIdPool is 
null"); - uniqueValueMask = (((long) taskId.getStageId().getId()) << 54) | (((long) taskId.getId()) << 40); + uniqueValueMask = (((long) taskId.getStageId().getId()) << 54) | (((long) taskId.getPartitionId()) << 40); requestValues(); } diff --git a/core/trino-main/src/main/java/io/trino/operator/DeduplicationExchangeClientBuffer.java b/core/trino-main/src/main/java/io/trino/operator/DeduplicationExchangeClientBuffer.java new file mode 100644 index 000000000000..54c8c13522f2 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/DeduplicationExchangeClientBuffer.java @@ -0,0 +1,354 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator; + +import com.google.common.collect.LinkedListMultimap; +import com.google.common.collect.ListMultimap; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.airlift.units.DataSize; +import io.trino.execution.TaskId; +import io.trino.execution.buffer.SerializedPage; +import io.trino.spi.TrinoException; + +import javax.annotation.concurrent.GuardedBy; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Executor; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Throwables.throwIfUnchecked; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; +import static io.trino.operator.RetryPolicy.QUERY; +import static io.trino.operator.RetryPolicy.TASK; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static java.lang.Math.max; +import static java.util.Objects.requireNonNull; + +public class DeduplicationExchangeClientBuffer + implements ExchangeClientBuffer +{ + private final Executor executor; + private final long bufferCapacityInBytes; + private final RetryPolicy retryPolicy; + + private final SettableFuture blocked = SettableFuture.create(); + @GuardedBy("this") + private final Set allTasks = new HashSet<>(); + @GuardedBy("this") + private boolean noMoreTasks; + @GuardedBy("this") + private final Set successfulTasks = new HashSet<>(); + @GuardedBy("this") + private final Map failedTasks = new HashMap<>(); + @GuardedBy("this") + private boolean inputFinished; + @GuardedBy("this") + private Throwable failure; + + @GuardedBy("this") 
+ private final ListMultimap pageBuffer = LinkedListMultimap.create(); + @GuardedBy("this") + private Iterator pagesIterator; + @GuardedBy("this") + private volatile long bufferRetainedSizeInBytes; + @GuardedBy("this") + private volatile long maxBufferRetainedSizeInBytes; + @GuardedBy("this") + private int maxAttemptId; + + @GuardedBy("this") + private boolean closed; + + public DeduplicationExchangeClientBuffer(Executor executor, DataSize bufferCapacity, RetryPolicy retryPolicy) + { + this.executor = requireNonNull(executor, "executor is null"); + this.bufferCapacityInBytes = requireNonNull(bufferCapacity, "bufferCapacity is null").toBytes(); + requireNonNull(retryPolicy, "retryPolicy is null"); + checkArgument(retryPolicy == QUERY, "retryPolicy is expected to be QUERY: %s", retryPolicy); + this.retryPolicy = retryPolicy; + } + + @Override + public ListenableFuture isBlocked() + { + return nonCancellationPropagating(blocked); + } + + @Override + public synchronized SerializedPage pollPage() + { + throwIfFailed(); + + if (closed) { + return null; + } + + if (!inputFinished) { + return null; + } + + if (pagesIterator == null) { + pagesIterator = pageBuffer.values().iterator(); + } + + if (!pagesIterator.hasNext()) { + return null; + } + + SerializedPage page = pagesIterator.next(); + pagesIterator.remove(); + bufferRetainedSizeInBytes -= page.getRetainedSizeInBytes(); + + return page; + } + + @Override + public synchronized void addTask(TaskId taskId) + { + if (closed) { + return; + } + + checkState(!noMoreTasks, "no more tasks expected"); + checkState(allTasks.add(taskId), "task already registered: %s", taskId); + + if (taskId.getAttemptId() > maxAttemptId) { + maxAttemptId = taskId.getAttemptId(); + + if (retryPolicy == QUERY) { + removePagesForPreviousAttempts(taskId.getAttemptId()); + } + } + } + + @Override + public synchronized void addPages(TaskId taskId, List pages) + { + if (closed) { + return; + } + + checkState(allTasks.contains(taskId), "task is not 
registered: %s", taskId); + checkState(!successfulTasks.contains(taskId), "task is finished: %s", taskId); + checkState(!failedTasks.containsKey(taskId), "task is failed: %s", taskId); + + if (failure != null) { + return; + } + + if (retryPolicy == QUERY && taskId.getAttemptId() < maxAttemptId) { + return; + } + + long pagesRetainedSizeInBytes = 0; + for (SerializedPage page : pages) { + pagesRetainedSizeInBytes += page.getRetainedSizeInBytes(); + } + bufferRetainedSizeInBytes += pagesRetainedSizeInBytes; + if (bufferRetainedSizeInBytes > bufferCapacityInBytes) { + // TODO: implement disk spilling + failure = new TrinoException(NOT_SUPPORTED, "Retries for queries with large result set currently unsupported"); + pageBuffer.clear(); + bufferRetainedSizeInBytes = 0; + unblock(blocked); + return; + } + maxBufferRetainedSizeInBytes = max(maxBufferRetainedSizeInBytes, bufferRetainedSizeInBytes); + pageBuffer.putAll(taskId, pages); + } + + @Override + public synchronized void taskFinished(TaskId taskId) + { + if (closed) { + return; + } + + checkState(allTasks.contains(taskId), "task is not registered: %s", taskId); + checkState(!failedTasks.containsKey(taskId), "task is failed: %s", taskId); + checkState(successfulTasks.add(taskId), "task is finished: %s", taskId); + + if (retryPolicy == TASK) { + // TODO implement deduplication for task level retries + throw new UnsupportedOperationException("task level retry policy is unsupported"); + } + checkInputFinished(); + } + + @Override + public synchronized void taskFailed(TaskId taskId, Throwable t) + { + if (closed) { + return; + } + + checkState(allTasks.contains(taskId), "task is not registered: %s", taskId); + checkState(!successfulTasks.contains(taskId), "task is finished: %s", taskId); + checkState(failedTasks.put(taskId, t) == null, "task is already failed: %s", taskId); + checkInputFinished(); + } + + @Override + public synchronized void noMoreTasks() + { + if (closed) { + return; + } + + noMoreTasks = true; + 
checkInputFinished(); + } + + private synchronized void checkInputFinished() + { + if (failure != null) { + return; + } + + if (inputFinished) { + return; + } + + if (!noMoreTasks) { + return; + } + + if (allTasks.isEmpty()) { + inputFinished = true; + unblock(blocked); + return; + } + + switch (retryPolicy) { + case TASK: + // TODO implement deduplication for task level retries + throw new UnsupportedOperationException("task level retry policy is unsupported"); + case QUERY: { + Set latestAttemptTasks = allTasks.stream() + .filter(taskId -> taskId.getAttemptId() == maxAttemptId) + .collect(toImmutableSet()); + + if (successfulTasks.containsAll(latestAttemptTasks)) { + removePagesForPreviousAttempts(maxAttemptId); + inputFinished = true; + unblock(blocked); + return; + } + + List failures = failedTasks.entrySet().stream() + .filter(entry -> entry.getKey().getAttemptId() == maxAttemptId) + .map(Map.Entry::getValue) + .collect(toImmutableList()); + + if (!failures.isEmpty()) { + Throwable failure = null; + for (Throwable taskFailure : failures) { + if (failure == null) { + failure = taskFailure; + } + else if (failure != taskFailure) { + failure.addSuppressed(taskFailure); + } + } + pageBuffer.clear(); + bufferRetainedSizeInBytes = 0; + this.failure = failure; + unblock(blocked); + } + break; + } + default: + throw new UnsupportedOperationException("unexpected retry policy: " + retryPolicy); + } + } + + private synchronized void removePagesForPreviousAttempts(int currentAttemptId) + { + // wipe previous attempt pages + long pagesRetainedSizeInBytes = 0; + Iterator> iterator = pageBuffer.entries().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + if (entry.getKey().getAttemptId() < currentAttemptId) { + pagesRetainedSizeInBytes += entry.getValue().getRetainedSizeInBytes(); + iterator.remove(); + } + } + bufferRetainedSizeInBytes -= pagesRetainedSizeInBytes; + } + + @Override + public synchronized boolean isFinished() + { + return closed 
|| failure != null || (inputFinished && pageBuffer.isEmpty()); + } + + @Override + public long getRemainingCapacityInBytes() + { + return max(bufferCapacityInBytes - bufferRetainedSizeInBytes, 0); + } + + @Override + public long getRetainedSizeInBytes() + { + return bufferRetainedSizeInBytes; + } + + @Override + public long getMaxRetainedSizeInBytes() + { + return maxBufferRetainedSizeInBytes; + } + + @Override + public synchronized int getBufferedPageCount() + { + return pageBuffer.size(); + } + + @Override + public synchronized void close() + { + if (closed) { + return; + } + closed = true; + pageBuffer.clear(); + bufferRetainedSizeInBytes = 0; + unblock(blocked); + } + + private synchronized void throwIfFailed() + { + if (failure != null) { + throwIfUnchecked(failure); + throw new RuntimeException(failure); + } + } + + private void unblock(SettableFuture blocked) + { + executor.execute(() -> blocked.set(null)); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeClient.java b/core/trino-main/src/main/java/io/trino/operator/ExchangeClient.java index 967d2d6ef7aa..15d793b4200a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/ExchangeClient.java @@ -15,12 +15,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; import io.airlift.http.client.HttpClient; import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.trino.FeaturesConfig.DataIntegrityVerification; -import io.trino.execution.buffer.PageCodecMarker; +import io.trino.execution.TaskFailureListener; +import io.trino.execution.TaskId; import io.trino.execution.buffer.SerializedPage; import io.trino.memory.context.LocalMemoryContext; import io.trino.operator.HttpPageBufferClient.ClientCallback; @@ -32,7 +32,6 @@ import java.io.Closeable; import java.net.URI; -import 
java.util.ArrayList; import java.util.Deque; import java.util.LinkedList; import java.util.List; @@ -40,34 +39,25 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Executor; -import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Throwables.throwIfUnchecked; import static com.google.common.collect.Sets.newConcurrentHashSet; -import static com.google.common.util.concurrent.Futures.immediateVoidFuture; -import static io.airlift.slice.Slices.EMPTY_SLICE; import static java.util.Objects.requireNonNull; @ThreadSafe public class ExchangeClient implements Closeable { - private static final SerializedPage NO_MORE_PAGES = new SerializedPage(EMPTY_SLICE, PageCodecMarker.MarkerSet.empty(), 0, 0); - private static final ListenableFuture NOT_BLOCKED = immediateVoidFuture(); - private final String selfAddress; private final DataIntegrityVerification dataIntegrityVerification; - private final long bufferCapacity; private final DataSize maxResponseSize; private final int concurrentRequestMultiplier; private final Duration maxErrorDuration; private final boolean acknowledgePages; private final HttpClient httpClient; - private final ScheduledExecutorService scheduler; + private final ScheduledExecutorService scheduledExecutor; @GuardedBy("this") private boolean noMoreLocations; @@ -78,53 +68,47 @@ public class ExchangeClient private final Deque queuedClients = new LinkedList<>(); private final Set completedClients = newConcurrentHashSet(); - private final LinkedBlockingDeque pageBuffer = new LinkedBlockingDeque<>(); - - @GuardedBy("this") - private final List> blockedCallers = new ArrayList<>(); + private final ExchangeClientBuffer buffer; - @GuardedBy("this") - 
private long bufferRetainedSizeInBytes; - @GuardedBy("this") - private long maxBufferRetainedSizeInBytes; @GuardedBy("this") private long successfulRequests; @GuardedBy("this") private long averageBytesPerRequest; private final AtomicBoolean closed = new AtomicBoolean(); - private final AtomicReference failure = new AtomicReference<>(); private final LocalMemoryContext systemMemoryContext; private final Executor pageBufferClientCallbackExecutor; + private final TaskFailureListener taskFailureListener; // ExchangeClientStatus.mergeWith assumes all clients have the same bufferCapacity. // Please change that method accordingly when this assumption becomes not true. public ExchangeClient( String selfAddress, DataIntegrityVerification dataIntegrityVerification, - DataSize bufferCapacity, + ExchangeClientBuffer buffer, DataSize maxResponseSize, int concurrentRequestMultiplier, Duration maxErrorDuration, boolean acknowledgePages, HttpClient httpClient, - ScheduledExecutorService scheduler, + ScheduledExecutorService scheduledExecutor, LocalMemoryContext systemMemoryContext, - Executor pageBufferClientCallbackExecutor) + Executor pageBufferClientCallbackExecutor, + TaskFailureListener taskFailureListener) { this.selfAddress = requireNonNull(selfAddress, "selfAddress is null"); this.dataIntegrityVerification = requireNonNull(dataIntegrityVerification, "dataIntegrityVerification is null"); - this.bufferCapacity = bufferCapacity.toBytes(); + this.buffer = requireNonNull(buffer, "buffer is null"); this.maxResponseSize = maxResponseSize; this.concurrentRequestMultiplier = concurrentRequestMultiplier; this.maxErrorDuration = maxErrorDuration; this.acknowledgePages = acknowledgePages; this.httpClient = httpClient; - this.scheduler = scheduler; + this.scheduledExecutor = scheduledExecutor; this.systemMemoryContext = systemMemoryContext; - this.maxBufferRetainedSizeInBytes = Long.MIN_VALUE; this.pageBufferClientCallbackExecutor = requireNonNull(pageBufferClientCallbackExecutor, 
"pageBufferClientCallbackExecutor is null"); + this.taskFailureListener = requireNonNull(taskFailureListener, "taskFailureListener is null"); } public ExchangeClientStatus getStatus() @@ -138,15 +122,18 @@ public ExchangeClientStatus getStatus() } List pageBufferClientStatus = pageBufferClientStatusBuilder.build(); synchronized (this) { - int bufferedPages = pageBuffer.size(); - if (bufferedPages > 0 && pageBuffer.peekLast() == NO_MORE_PAGES) { - bufferedPages--; - } - return new ExchangeClientStatus(bufferRetainedSizeInBytes, maxBufferRetainedSizeInBytes, averageBytesPerRequest, successfulRequests, bufferedPages, noMoreLocations, pageBufferClientStatus); + return new ExchangeClientStatus( + buffer.getRetainedSizeInBytes(), + buffer.getMaxRetainedSizeInBytes(), + averageBytesPerRequest, + successfulRequests, + buffer.getBufferedPageCount(), + noMoreLocations, + pageBufferClientStatus); } } - public synchronized void addLocation(URI location) + public synchronized void addLocation(TaskId taskId, URI location) { requireNonNull(location, "location is null"); @@ -162,7 +149,7 @@ public synchronized void addLocation(URI location) } checkState(!noMoreLocations, "No more locations already set"); - + buffer.addTask(taskId); HttpPageBufferClient client = new HttpPageBufferClient( selfAddress, httpClient, @@ -170,9 +157,10 @@ public synchronized void addLocation(URI location) maxResponseSize, maxErrorDuration, acknowledgePages, + taskId, location, new ExchangeClientCallback(), - scheduler, + scheduledExecutor, pageBufferClientCallbackExecutor); allClients.put(location, client); queuedClients.add(client); @@ -183,6 +171,7 @@ public synchronized void addLocation(URI location) public synchronized void noMoreLocations() { noMoreLocations = true; + buffer.noMoreTasks(); scheduleRequestIfNecessary(); } @@ -218,49 +207,25 @@ public SerializedPage pollPage() { assertNotHoldsLock(); - throwIfFailed(); - if (closed.get()) { return null; } - SerializedPage page = pageBuffer.poll(); + 
SerializedPage page = buffer.pollPage(); if (page == null) { return null; } - if (page == NO_MORE_PAGES) { - // mark client closed; close() will add the end marker - close(); - - notifyBlockedCallers(); - - // don't return end of stream marker - return null; - } - - synchronized (this) { - if (!closed.get()) { - bufferRetainedSizeInBytes -= page.getRetainedSizeInBytes(); - systemMemoryContext.setBytes(bufferRetainedSizeInBytes); - } - scheduleRequestIfNecessary(); - } + systemMemoryContext.setBytes(buffer.getRetainedSizeInBytes()); + scheduleRequestIfNecessary(); return page; } public boolean isFinished() { - throwIfFailed(); - // For this to works, locations must never be added after is closed is set - return isClosed() && completedClients.size() == allClients.size(); - } - - public boolean isClosed() - { - return closed.get(); + return buffer.isFinished() && completedClients.size() == allClients.size(); } @Override @@ -273,34 +238,17 @@ public synchronized void close() for (HttpPageBufferClient client : allClients.values()) { closeQuietly(client); } - pageBuffer.clear(); + buffer.close(); systemMemoryContext.setBytes(0); - bufferRetainedSizeInBytes = 0; - if (pageBuffer.peekLast() != NO_MORE_PAGES) { - checkState(pageBuffer.add(NO_MORE_PAGES), "Could not add no more pages marker"); - } - notifyBlockedCallers(); } private synchronized void scheduleRequestIfNecessary() { - if (isFinished() || isFailed()) { + if (isFinished()) { return; } - // if finished, add the end marker - if (noMoreLocations && completedClients.size() == allClients.size()) { - if (pageBuffer.peekLast() != NO_MORE_PAGES) { - checkState(pageBuffer.add(NO_MORE_PAGES), "Could not add no more pages marker"); - } - if (pageBuffer.peek() == NO_MORE_PAGES) { - close(); - } - notifyBlockedCallers(); - return; - } - - long neededBytes = bufferCapacity - bufferRetainedSizeInBytes; + long neededBytes = buffer.getRemainingCapacityInBytes(); if (neededBytes <= 0) { return; } @@ -323,81 +271,40 @@ private 
synchronized void scheduleRequestIfNecessary() public ListenableFuture isBlocked() { - // Fast path pre-check - if (isClosed() || isFailed() || pageBuffer.peek() != null) { - return NOT_BLOCKED; - } - synchronized (this) { - // Recheck after acquiring the lock - if (isClosed() || isFailed() || pageBuffer.peek() != null) { - return NOT_BLOCKED; - } - SettableFuture future = SettableFuture.create(); - blockedCallers.add(future); - return future; - } + return buffer.isBlocked(); } - private boolean addPages(List pages) + private boolean addPages(HttpPageBufferClient client, List pages) { + checkState(!completedClients.contains(client), "client is already marked as completed"); // Compute stats before acquiring the lock - long pagesRetainedSizeInBytes = 0; long responseSize = 0; for (SerializedPage page : pages) { - pagesRetainedSizeInBytes += page.getRetainedSizeInBytes(); responseSize += page.getSizeInBytes(); } - List> notify = ImmutableList.of(); synchronized (this) { - if (isClosed() || isFailed()) { + if (closed.get() || buffer.isFinished()) { return false; } - if (!pages.isEmpty()) { - pageBuffer.addAll(pages); - - bufferRetainedSizeInBytes += pagesRetainedSizeInBytes; - maxBufferRetainedSizeInBytes = Math.max(maxBufferRetainedSizeInBytes, bufferRetainedSizeInBytes); - systemMemoryContext.setBytes(bufferRetainedSizeInBytes); - - // Notify pending listeners that a page has been added - notify = ImmutableList.copyOf(blockedCallers); - blockedCallers.clear(); - } - successfulRequests++; // AVG_n = AVG_(n-1) * (n-1)/n + VALUE_n / n averageBytesPerRequest = (long) (1.0 * averageBytesPerRequest * (successfulRequests - 1) / successfulRequests + responseSize / successfulRequests); } - // Trigger notifications after releasing the lock - notifyListeners(notify); - - return true; - } - - private void notifyBlockedCallers() - { - List> callers; - synchronized (this) { - callers = ImmutableList.copyOf(blockedCallers); - blockedCallers.clear(); + // add pages outside of the 
lock + if (!pages.isEmpty()) { + buffer.addPages(client.getRemoteTaskId(), pages); + systemMemoryContext.setBytes(buffer.getRetainedSizeInBytes()); } - notifyListeners(callers); - } - private void notifyListeners(List> blockedCallers) - { - for (SettableFuture blockedCaller : blockedCallers) { - // Notify callers in a separate thread to avoid callbacks while holding a lock - scheduler.execute(() -> blockedCaller.set(null)); - } + return true; } private synchronized void requestComplete(HttpPageBufferClient client) { - if (!queuedClients.contains(client)) { + if (!completedClients.contains(client) && !queuedClients.contains(client)) { queuedClients.add(client); } scheduleRequestIfNecessary(); @@ -406,32 +313,21 @@ private synchronized void requestComplete(HttpPageBufferClient client) private synchronized void clientFinished(HttpPageBufferClient client) { requireNonNull(client, "client is null"); - completedClients.add(client); - scheduleRequestIfNecessary(); - } - - private synchronized void clientFailed(Throwable cause) - { - // TODO: properly handle the failed vs closed state - // it is important not to treat failures as a successful close - if (!isClosed()) { - failure.compareAndSet(null, cause); - notifyBlockedCallers(); + if (completedClients.add(client)) { + buffer.taskFinished(client.getRemoteTaskId()); } + scheduleRequestIfNecessary(); } - private boolean isFailed() - { - return failure.get() != null; - } - - private void throwIfFailed() + private synchronized void clientFailed(HttpPageBufferClient client, Throwable cause) { - Throwable t = failure.get(); - if (t != null) { - throwIfUnchecked(t); - throw new RuntimeException(t); + requireNonNull(client, "client is null"); + if (completedClients.add(client)) { + buffer.taskFailed(client.getRemoteTaskId(), cause); + scheduledExecutor.execute(() -> taskFailureListener.onTaskFailed(client.getRemoteTaskId(), cause)); + closeQuietly(client); } + scheduleRequestIfNecessary(); } private class ExchangeClientCallback 
@@ -442,7 +338,7 @@ public boolean addPages(HttpPageBufferClient client, List pages) { requireNonNull(client, "client is null"); requireNonNull(pages, "pages is null"); - return ExchangeClient.this.addPages(pages); + return ExchangeClient.this.addPages(client, pages); } @Override @@ -463,7 +359,7 @@ public void clientFailed(HttpPageBufferClient client, Throwable cause) { requireNonNull(client, "client is null"); requireNonNull(cause, "cause is null"); - ExchangeClient.this.clientFailed(cause); + ExchangeClient.this.clientFailed(client, cause); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientBuffer.java b/core/trino-main/src/main/java/io/trino/operator/ExchangeClientBuffer.java new file mode 100644 index 000000000000..7990a63a510b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/ExchangeClientBuffer.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import com.google.common.util.concurrent.ListenableFuture; +import io.trino.execution.TaskId; +import io.trino.execution.buffer.SerializedPage; + +import java.io.Closeable; +import java.util.List; + +public interface ExchangeClientBuffer + extends Closeable +{ + /** + * This method may be called by multiple independent client concurrently. + * Implementations must ensure the cancellation of a future by one of the clients + * doesn't cancel futures obtained by other clients. 
+ */ + ListenableFuture isBlocked(); + + SerializedPage pollPage(); + + void addTask(TaskId taskId); + + void addPages(TaskId taskId, List pages); + + void taskFinished(TaskId taskId); + + void taskFailed(TaskId taskId, Throwable t); + + void noMoreTasks(); + + boolean isFinished(); + + long getRemainingCapacityInBytes(); + + long getRetainedSizeInBytes(); + + long getMaxRetainedSizeInBytes(); + + int getBufferedPageCount(); + + @Override + void close(); +} diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientFactory.java b/core/trino-main/src/main/java/io/trino/operator/ExchangeClientFactory.java index 34e766b9eab0..58eafe70cef9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/ExchangeClientFactory.java @@ -20,6 +20,7 @@ import io.airlift.units.Duration; import io.trino.FeaturesConfig; import io.trino.FeaturesConfig.DataIntegrityVerification; +import io.trino.execution.TaskFailureListener; import io.trino.memory.context.LocalMemoryContext; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; @@ -122,12 +123,25 @@ public ThreadPoolExecutorMBean getExecutor() } @Override - public ExchangeClient get(LocalMemoryContext systemMemoryContext) + public ExchangeClient get(LocalMemoryContext systemMemoryContext, TaskFailureListener taskFailureListener, RetryPolicy retryPolicy) { + ExchangeClientBuffer buffer; + switch (retryPolicy) { + case TASK: + case QUERY: + buffer = new DeduplicationExchangeClientBuffer(scheduler, maxBufferedBytes, retryPolicy); + break; + case NONE: + buffer = new StreamingExchangeClientBuffer(scheduler, maxBufferedBytes); + break; + default: + throw new IllegalArgumentException("unexpected retry policy: " + retryPolicy); + } + return new ExchangeClient( nodeInfo.getExternalAddress(), dataIntegrityVerification, - maxBufferedBytes, + buffer, maxResponseSize, concurrentRequestMultiplier, maxErrorDuration, @@ -135,6 +149,7 @@ 
public ExchangeClient get(LocalMemoryContext systemMemoryContext) httpClient, scheduler, systemMemoryContext, - pageBufferClientCallbackExecutor); + pageBufferClientCallbackExecutor, + taskFailureListener); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientSupplier.java b/core/trino-main/src/main/java/io/trino/operator/ExchangeClientSupplier.java index 54b9f7802bdf..74176f88a267 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientSupplier.java +++ b/core/trino-main/src/main/java/io/trino/operator/ExchangeClientSupplier.java @@ -13,9 +13,10 @@ */ package io.trino.operator; +import io.trino.execution.TaskFailureListener; import io.trino.memory.context.LocalMemoryContext; public interface ExchangeClientSupplier { - ExchangeClient get(LocalMemoryContext systemMemoryContext); + ExchangeClient get(LocalMemoryContext systemMemoryContext, TaskFailureListener taskFailureListener, RetryPolicy retryPolicy); } diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java b/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java index a3aab178f705..0893c08d2648 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java @@ -24,7 +24,6 @@ import io.trino.split.RemoteSplit; import io.trino.sql.planner.plan.PlanNodeId; -import java.net.URI; import java.util.Optional; import java.util.function.Supplier; @@ -44,6 +43,7 @@ public static class ExchangeOperatorFactory private final PlanNodeId sourceId; private final ExchangeClientSupplier exchangeClientSupplier; private final PagesSerdeFactory serdeFactory; + private final RetryPolicy retryPolicy; private ExchangeClient exchangeClient; private boolean closed; @@ -51,12 +51,14 @@ public ExchangeOperatorFactory( int operatorId, PlanNodeId sourceId, ExchangeClientSupplier exchangeClientSupplier, - PagesSerdeFactory serdeFactory) + PagesSerdeFactory 
serdeFactory, + RetryPolicy retryPolicy) { this.operatorId = operatorId; this.sourceId = sourceId; this.exchangeClientSupplier = exchangeClientSupplier; this.serdeFactory = serdeFactory; + this.retryPolicy = requireNonNull(retryPolicy, "retryPolicy is null"); } @Override @@ -69,9 +71,10 @@ public PlanNodeId getSourceId() public SourceOperator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); + TaskContext taskContext = driverContext.getPipelineContext().getTaskContext(); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, sourceId, ExchangeOperator.class.getSimpleName()); if (exchangeClient == null) { - exchangeClient = exchangeClientSupplier.get(driverContext.getPipelineContext().localSystemMemoryContext()); + exchangeClient = exchangeClientSupplier.get(driverContext.getPipelineContext().localSystemMemoryContext(), taskContext::sourceTaskFailed, retryPolicy); } return new ExchangeOperator( @@ -120,8 +123,8 @@ public Supplier> addSplit(Split split) requireNonNull(split, "split is null"); checkArgument(split.getCatalogName().equals(REMOTE_CONNECTOR_ID), "split is not a remote split"); - URI location = ((RemoteSplit) split.getConnectorSplit()).getLocation(); - exchangeClient.addLocation(location); + RemoteSplit remoteSplit = (RemoteSplit) split.getConnectorSplit(); + exchangeClient.addLocation(remoteSplit.getTaskId(), remoteSplit.getLocation()); return Optional::empty; } diff --git a/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java b/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java index b6cee274f0ec..34cdcafa324c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java @@ -32,6 +32,7 @@ import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.trino.FeaturesConfig.DataIntegrityVerification; +import 
io.trino.execution.TaskId; import io.trino.execution.buffer.SerializedPage; import io.trino.server.remotetask.Backoff; import io.trino.spi.TrinoException; @@ -77,11 +78,13 @@ import static io.trino.server.InternalHeaders.TRINO_MAX_SIZE; import static io.trino.server.InternalHeaders.TRINO_PAGE_NEXT_TOKEN; import static io.trino.server.InternalHeaders.TRINO_PAGE_TOKEN; +import static io.trino.server.InternalHeaders.TRINO_TASK_FAILED; import static io.trino.server.InternalHeaders.TRINO_TASK_INSTANCE_ID; import static io.trino.server.PagesResponseWriter.SERIALIZED_PAGES_MAGIC; import static io.trino.spi.HostAddress.fromUri; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.REMOTE_BUFFER_CLOSE_FAILED; +import static io.trino.spi.StandardErrorCode.REMOTE_TASK_FAILED; import static io.trino.spi.StandardErrorCode.REMOTE_TASK_MISMATCH; import static io.trino.util.Failures.REMOTE_TASK_MISMATCH_ERROR; import static io.trino.util.Failures.WORKER_NODE_ERROR; @@ -121,9 +124,10 @@ public interface ClientCallback private final DataIntegrityVerification dataIntegrityVerification; private final DataSize maxResponseSize; private final boolean acknowledgePages; + private final TaskId remoteTaskId; private final URI location; private final ClientCallback clientCallback; - private final ScheduledExecutorService scheduler; + private final ScheduledExecutorService scheduledExecutor; private final Backoff backoff; @GuardedBy("this") @@ -160,9 +164,10 @@ public HttpPageBufferClient( DataSize maxResponseSize, Duration maxErrorDuration, boolean acknowledgePages, + TaskId remoteTaskId, URI location, ClientCallback clientCallback, - ScheduledExecutorService scheduler, + ScheduledExecutorService scheduledExecutor, Executor pageBufferClientCallbackExecutor) { this( @@ -172,9 +177,10 @@ public HttpPageBufferClient( maxResponseSize, maxErrorDuration, acknowledgePages, + remoteTaskId, location, clientCallback, - scheduler, + 
scheduledExecutor, Ticker.systemTicker(), pageBufferClientCallbackExecutor); } @@ -186,9 +192,10 @@ public HttpPageBufferClient( DataSize maxResponseSize, Duration maxErrorDuration, boolean acknowledgePages, + TaskId remoteTaskId, URI location, ClientCallback clientCallback, - ScheduledExecutorService scheduler, + ScheduledExecutorService scheduledExecutor, Ticker ticker, Executor pageBufferClientCallbackExecutor) { @@ -197,9 +204,10 @@ public HttpPageBufferClient( this.dataIntegrityVerification = requireNonNull(dataIntegrityVerification, "dataIntegrityVerification is null"); this.maxResponseSize = requireNonNull(maxResponseSize, "maxResponseSize is null"); this.acknowledgePages = acknowledgePages; + this.remoteTaskId = requireNonNull(remoteTaskId, "remoteTaskId is null"); this.location = requireNonNull(location, "location is null"); this.clientCallback = requireNonNull(clientCallback, "clientCallback is null"); - this.scheduler = requireNonNull(scheduler, "scheduler is null"); + this.scheduledExecutor = requireNonNull(scheduledExecutor, "scheduledExecutor is null"); this.pageBufferClientCallbackExecutor = requireNonNull(pageBufferClientCallbackExecutor, "pageBufferClientCallbackExecutor is null"); requireNonNull(maxErrorDuration, "maxErrorDuration is null"); requireNonNull(ticker, "ticker is null"); @@ -246,6 +254,11 @@ else if (completed) { httpRequestState); } + public TaskId getRemoteTaskId() + { + return remoteTaskId; + } + public synchronized boolean isRunning() { return future != null; @@ -289,7 +302,7 @@ public synchronized void scheduleRequest() backoff.startRequest(); long delayNanos = backoff.getBackoffDelayNanos(); - scheduler.schedule(() -> { + scheduledExecutor.schedule(() -> { try { initiateRequest(); } @@ -341,6 +354,10 @@ public void onSuccess(PagesResponse result) List pages; try { + if (result.isTaskFailed()) { + throw new TrinoException(REMOTE_TASK_FAILED, format("Remote task failed: %s", remoteTaskId)); + } + boolean shouldAcknowledge = false; 
synchronized (HttpPageBufferClient.this) { if (taskInstanceId == null) { @@ -615,7 +632,12 @@ public PagesResponse handle(Request request, Response response) // no content means no content was created within the wait period, but query is still ok // if job is finished, complete is set in the response if (response.getStatusCode() == HttpStatus.NO_CONTENT.code()) { - return createEmptyPagesResponse(getTaskInstanceId(response, uri), getToken(response, uri), getNextToken(response, uri), getComplete(response, uri)); + return createEmptyPagesResponse( + getTaskInstanceId(response, uri), + getToken(response, uri), + getNextToken(response, uri), + getComplete(response, uri), + getTaskFailed(response, uri)); } // otherwise we must have gotten an OK response, everything else is considered fatal @@ -651,6 +673,7 @@ public PagesResponse handle(Request request, Response response) long token = getToken(response, uri); long nextToken = getNextToken(response, uri); boolean complete = getComplete(response, uri); + boolean remoteTaskFailed = getTaskFailed(response, uri); try (SliceInput input = new InputStreamSliceInput(response.getInputStream())) { int magic = input.readInt(); @@ -662,7 +685,7 @@ public PagesResponse handle(Request request, Response response) List pages = ImmutableList.copyOf(readSerializedPages(input)); verifyChecksum(checksum, pages); checkState(pages.size() == pagesCount, "Wrong number of pages, expected %s, but read %s", pagesCount, pages.size()); - return createPagesResponse(taskInstanceId, token, nextToken, pages, complete); + return createPagesResponse(taskInstanceId, token, nextToken, pages, complete, remoteTaskFailed); } catch (IOException e) { throw new RuntimeException(e); @@ -724,6 +747,15 @@ private static boolean getComplete(Response response, URI uri) return Boolean.parseBoolean(bufferComplete); } + private static boolean getTaskFailed(Response response, URI uri) + { + String taskFailed = response.getHeader(TRINO_TASK_FAILED); + if (taskFailed == 
null) { + throw new PageTransportErrorException(fromUri(uri), format("Expected %s header", TRINO_TASK_FAILED)); + } + return Boolean.parseBoolean(taskFailed); + } + private static boolean mediaTypeMatches(String value, MediaType range) { try { @@ -737,14 +769,14 @@ private static boolean mediaTypeMatches(String value, MediaType range) public static class PagesResponse { - public static PagesResponse createPagesResponse(String taskInstanceId, long token, long nextToken, Iterable pages, boolean complete) + public static PagesResponse createPagesResponse(String taskInstanceId, long token, long nextToken, Iterable pages, boolean complete, boolean taskFailed) { - return new PagesResponse(taskInstanceId, token, nextToken, pages, complete); + return new PagesResponse(taskInstanceId, token, nextToken, pages, complete, taskFailed); } - public static PagesResponse createEmptyPagesResponse(String taskInstanceId, long token, long nextToken, boolean complete) + public static PagesResponse createEmptyPagesResponse(String taskInstanceId, long token, long nextToken, boolean complete, boolean taskFailed) { - return new PagesResponse(taskInstanceId, token, nextToken, ImmutableList.of(), complete); + return new PagesResponse(taskInstanceId, token, nextToken, ImmutableList.of(), complete, taskFailed); } private final String taskInstanceId; @@ -752,14 +784,16 @@ public static PagesResponse createEmptyPagesResponse(String taskInstanceId, long private final long nextToken; private final List pages; private final boolean clientComplete; + private final boolean taskFailed; - private PagesResponse(String taskInstanceId, long token, long nextToken, Iterable pages, boolean clientComplete) + private PagesResponse(String taskInstanceId, long token, long nextToken, Iterable pages, boolean clientComplete, boolean taskFailed) { this.taskInstanceId = taskInstanceId; this.token = token; this.nextToken = nextToken; this.pages = ImmutableList.copyOf(pages); this.clientComplete = clientComplete; + 
this.taskFailed = taskFailed; } public long getToken() @@ -787,6 +821,11 @@ public String getTaskInstanceId() return taskInstanceId; } + public boolean isTaskFailed() + { + return taskFailed; + } + @Override public String toString() { @@ -795,6 +834,7 @@ public String toString() .add("nextToken", nextToken) .add("pagesSize", pages.size()) .add("clientComplete", clientComplete) + .add("taskFailed", taskFailed) .toString(); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/MergeOperator.java b/core/trino-main/src/main/java/io/trino/operator/MergeOperator.java index 25e3c22b8a5d..8a32faae1469 100644 --- a/core/trino-main/src/main/java/io/trino/operator/MergeOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/MergeOperator.java @@ -29,7 +29,6 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.net.URI; import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -158,9 +157,10 @@ public Supplier> addSplit(Split split) checkArgument(split.getConnectorSplit() instanceof RemoteSplit, "split is not a remote split"); checkState(!blockedOnSplits.isDone(), "noMoreSplits has been called already"); - URI location = ((RemoteSplit) split.getConnectorSplit()).getLocation(); - ExchangeClient exchangeClient = closer.register(exchangeClientSupplier.get(operatorContext.localSystemMemoryContext())); - exchangeClient.addLocation(location); + TaskContext taskContext = operatorContext.getDriverContext().getPipelineContext().getTaskContext(); + ExchangeClient exchangeClient = closer.register(exchangeClientSupplier.get(operatorContext.localSystemMemoryContext(), taskContext::sourceTaskFailed, RetryPolicy.NONE)); + RemoteSplit remoteSplit = (RemoteSplit) split.getConnectorSplit(); + exchangeClient.addLocation(remoteSplit.getTaskId(), remoteSplit.getLocation()); exchangeClient.noMoreLocations(); pageProducers.add(exchangeClient.pages() .map(serializedPage -> { diff --git 
a/core/trino-main/src/main/java/io/trino/operator/RetryPolicy.java b/core/trino-main/src/main/java/io/trino/operator/RetryPolicy.java new file mode 100644 index 000000000000..d28f5138137b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/RetryPolicy.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +public enum RetryPolicy +{ + TASK, + QUERY, + NONE, + /**/; +} diff --git a/core/trino-main/src/main/java/io/trino/operator/StreamingExchangeClientBuffer.java b/core/trino-main/src/main/java/io/trino/operator/StreamingExchangeClientBuffer.java new file mode 100644 index 000000000000..c027ff395a41 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/StreamingExchangeClientBuffer.java @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.airlift.units.DataSize; +import io.trino.execution.TaskId; +import io.trino.execution.buffer.SerializedPage; +import io.trino.spi.TrinoException; + +import javax.annotation.concurrent.GuardedBy; + +import java.util.ArrayDeque; +import java.util.HashSet; +import java.util.List; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.Executor; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Throwables.throwIfUnchecked; +import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; +import static io.trino.spi.StandardErrorCode.REMOTE_TASK_FAILED; +import static java.lang.Math.max; +import static java.util.Objects.requireNonNull; + +public class StreamingExchangeClientBuffer + implements ExchangeClientBuffer +{ + private final Executor executor; + private final long bufferCapacityInBytes; + + @GuardedBy("this") + private final Queue bufferedPages = new ArrayDeque<>(); + @GuardedBy("this") + private volatile long bufferRetainedSizeInBytes; + @GuardedBy("this") + private volatile long maxBufferRetainedSizeInBytes; + @GuardedBy("this") + private volatile SettableFuture blocked = SettableFuture.create(); + @GuardedBy("this") + private final Set activeTasks = new HashSet<>(); + @GuardedBy("this") + private boolean noMoreTasks; + @GuardedBy("this") + private Throwable failure; + @GuardedBy("this") + private boolean closed; + + public StreamingExchangeClientBuffer(Executor executor, DataSize bufferCapacity) + { + this.executor = requireNonNull(executor, "executor is null"); + this.bufferCapacityInBytes = requireNonNull(bufferCapacity, "bufferCapacity is null").toBytes(); + } + + @Override + public ListenableFuture isBlocked() + { + return nonCancellationPropagating(blocked); + } + + @Override + public synchronized 
SerializedPage pollPage() + { + throwIfFailed(); + + if (closed) { + return null; + } + SerializedPage page = bufferedPages.poll(); + if (page != null) { + bufferRetainedSizeInBytes -= page.getRetainedSizeInBytes(); + checkState(bufferRetainedSizeInBytes >= 0, "unexpected bufferRetainedSizeInBytes: %s", bufferRetainedSizeInBytes); + } + // if buffer is empty block future calls + if (bufferedPages.isEmpty() && !isFinished() && blocked.isDone()) { + blocked = SettableFuture.create(); + } + return page; + } + + @Override + public synchronized void addTask(TaskId taskId) + { + if (closed) { + return; + } + checkState(!noMoreTasks, "no more tasks are expected"); + activeTasks.add(taskId); + } + + @Override + public void addPages(TaskId taskId, List pages) + { + long pagesRetainedSizeInBytes = 0; + for (SerializedPage page : pages) { + pagesRetainedSizeInBytes += page.getRetainedSizeInBytes(); + } + synchronized (this) { + if (closed) { + return; + } + checkState(activeTasks.contains(taskId), "taskId is not active: %s", taskId); + bufferedPages.addAll(pages); + bufferRetainedSizeInBytes += pagesRetainedSizeInBytes; + maxBufferRetainedSizeInBytes = max(maxBufferRetainedSizeInBytes, bufferRetainedSizeInBytes); + unblockIfNecessary(blocked); + } + } + + @Override + public synchronized void taskFinished(TaskId taskId) + { + if (closed) { + return; + } + checkState(activeTasks.contains(taskId), "taskId not registered: %s", taskId); + activeTasks.remove(taskId); + if (noMoreTasks && activeTasks.isEmpty() && !blocked.isDone()) { + unblockIfNecessary(blocked); + } + } + + @Override + public synchronized void taskFailed(TaskId taskId, Throwable t) + { + if (closed) { + return; + } + checkState(activeTasks.contains(taskId), "taskId not registered: %s", taskId); + + if (t instanceof TrinoException && ((TrinoException) t).getErrorCode() == REMOTE_TASK_FAILED.toErrorCode()) { + // let coordinator handle this + return; + } + + failure = t; + activeTasks.remove(taskId); + 
unblockIfNecessary(blocked); + } + + @Override + public synchronized void noMoreTasks() + { + noMoreTasks = true; + if (activeTasks.isEmpty() && !blocked.isDone()) { + unblockIfNecessary(blocked); + } + } + + @Override + public synchronized boolean isFinished() + { + return failure != null || (noMoreTasks && activeTasks.isEmpty() && bufferedPages.isEmpty()); + } + + @Override + public long getRemainingCapacityInBytes() + { + return max(bufferCapacityInBytes - bufferRetainedSizeInBytes, 0); + } + + @Override + public long getRetainedSizeInBytes() + { + return bufferRetainedSizeInBytes; + } + + @Override + public long getMaxRetainedSizeInBytes() + { + return maxBufferRetainedSizeInBytes; + } + + @Override + public synchronized int getBufferedPageCount() + { + return bufferedPages.size(); + } + + @Override + public synchronized void close() + { + if (closed) { + return; + } + bufferedPages.clear(); + bufferRetainedSizeInBytes = 0; + activeTasks.clear(); + noMoreTasks = true; + closed = true; + unblockIfNecessary(blocked); + } + + private void unblockIfNecessary(SettableFuture blocked) + { + if (!blocked.isDone()) { + executor.execute(() -> blocked.set(null)); + } + } + + private synchronized void throwIfFailed() + { + if (failure != null) { + throwIfUnchecked(failure); + throw new RuntimeException(failure); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/TaskContext.java b/core/trino-main/src/main/java/io/trino/operator/TaskContext.java index cb66ff328968..fd08d73d9f96 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TaskContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/TaskContext.java @@ -613,4 +613,9 @@ public void addDynamicFilter(Map dynamicFilterDomains) { localDynamicFiltersCollector.collectDynamicFilterDomains(dynamicFilterDomains); } + + public void sourceTaskFailed(TaskId taskId, Throwable failure) + { + taskStateMachine.sourceTaskFailed(taskId, failure); + } } diff --git 
a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index 427cc4487c5d..d4a58e0d227a 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -90,6 +90,7 @@ import io.trino.sql.planner.PlanOptimizers; import io.trino.sql.planner.PlanOptimizersFactory; import io.trino.sql.planner.RuleStatsRecorder; +import io.trino.sql.planner.SplitSourceFactory; import io.trino.sql.rewrite.DescribeInputRewrite; import io.trino.sql.rewrite.DescribeOutputRewrite; import io.trino.sql.rewrite.ExplainRewrite; @@ -277,6 +278,7 @@ protected void setup(Binder binder) newExporter(binder).export(QueryExecutionMBean.class) .as(generator -> generator.generatedNameOf(QueryExecution.class)); + binder.bind(SplitSourceFactory.class).in(Scopes.SINGLETON); binder.bind(SplitSchedulerStats.class).in(Scopes.SINGLETON); newExporter(binder).export(SplitSchedulerStats.class).withGeneratedName(); diff --git a/core/trino-main/src/main/java/io/trino/server/InternalHeaders.java b/core/trino-main/src/main/java/io/trino/server/InternalHeaders.java index d9e333e40deb..bad1681be03b 100644 --- a/core/trino-main/src/main/java/io/trino/server/InternalHeaders.java +++ b/core/trino-main/src/main/java/io/trino/server/InternalHeaders.java @@ -22,6 +22,7 @@ public final class InternalHeaders public static final String TRINO_PAGE_TOKEN = "X-Trino-Page-Sequence-Id"; public static final String TRINO_PAGE_NEXT_TOKEN = "X-Trino-Page-End-Sequence-Id"; public static final String TRINO_BUFFER_COMPLETE = "X-Trino-Buffer-Complete"; + public static final String TRINO_TASK_FAILED = "X-Trino-Task-Failed"; private InternalHeaders() {} } diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java index 846a8600ffee..4bd96f692b31 100644 --- 
a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java @@ -38,6 +38,8 @@ import io.trino.event.SplitMonitor; import io.trino.execution.DynamicFilterConfig; import io.trino.execution.ExplainAnalyzeContext; +import io.trino.execution.FailureInjectionConfig; +import io.trino.execution.FailureInjector; import io.trino.execution.LocationFactory; import io.trino.execution.MemoryRevokingScheduler; import io.trino.execution.NodeTaskMap; @@ -269,6 +271,8 @@ protected void setup(Binder binder) new TopologyAwareNodeSelectorModule())); // task execution + configBinder(binder).bindConfig(FailureInjectionConfig.class); + binder.bind(FailureInjector.class).in(Scopes.SINGLETON); jaxrsBinder(binder).bind(TaskResource.class); newExporter(binder).export(TaskResource.class).withGeneratedName(); jaxrsBinder(binder).bind(TaskExecutorResource.class); diff --git a/core/trino-main/src/main/java/io/trino/server/TaskResource.java b/core/trino-main/src/main/java/io/trino/server/TaskResource.java index 3154f92553ef..c932bf35876b 100644 --- a/core/trino-main/src/main/java/io/trino/server/TaskResource.java +++ b/core/trino-main/src/main/java/io/trino/server/TaskResource.java @@ -18,14 +18,17 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.concurrent.BoundedExecutor; +import io.airlift.log.Logger; import io.airlift.stats.TimeStat; import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.trino.Session; -import io.trino.execution.DynamicFiltersCollector.VersionedDynamicFilterDomains; +import io.trino.execution.FailureInjector; +import io.trino.execution.FailureInjector.InjectedFailure; import io.trino.execution.TaskId; import io.trino.execution.TaskInfo; import io.trino.execution.TaskManager; +import io.trino.execution.TaskState; import io.trino.execution.TaskStatus; import 
io.trino.execution.buffer.BufferResult; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; @@ -57,6 +60,7 @@ import javax.ws.rs.core.UriInfo; import java.util.List; +import java.util.Optional; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadLocalRandom; @@ -72,6 +76,7 @@ import static io.trino.server.InternalHeaders.TRINO_MAX_WAIT; import static io.trino.server.InternalHeaders.TRINO_PAGE_NEXT_TOKEN; import static io.trino.server.InternalHeaders.TRINO_PAGE_TOKEN; +import static io.trino.server.InternalHeaders.TRINO_TASK_FAILED; import static io.trino.server.InternalHeaders.TRINO_TASK_INSTANCE_ID; import static io.trino.server.security.ResourceSecurity.AccessType.INTERNAL_ONLY; import static java.util.Objects.requireNonNull; @@ -84,6 +89,8 @@ @Path("/v1/task") public class TaskResource { + private static final Logger log = Logger.get(TaskResource.class); + private static final Duration ADDITIONAL_WAIT_TIME = new Duration(5, SECONDS); private static final Duration DEFAULT_MAX_WAIT_TIME = new Duration(2, SECONDS); @@ -91,6 +98,7 @@ public class TaskResource private final SessionPropertyManager sessionPropertyManager; private final Executor responseExecutor; private final ScheduledExecutorService timeoutExecutor; + private final FailureInjector failureInjector; private final TimeStat readFromOutputBufferTime = new TimeStat(); private final TimeStat resultsRequestTime = new TimeStat(); @@ -99,12 +107,14 @@ public TaskResource( TaskManager taskManager, SessionPropertyManager sessionPropertyManager, @ForAsyncHttp BoundedExecutor responseExecutor, - @ForAsyncHttp ScheduledExecutorService timeoutExecutor) + @ForAsyncHttp ScheduledExecutorService timeoutExecutor, + FailureInjector failureInjector) { this.taskManager = requireNonNull(taskManager, "taskManager is null"); this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null"); 
this.responseExecutor = requireNonNull(responseExecutor, "responseExecutor is null"); this.timeoutExecutor = requireNonNull(timeoutExecutor, "timeoutExecutor is null"); + this.failureInjector = requireNonNull(failureInjector, "failureInjector is null"); } @ResourceSecurity(INTERNAL_ONLY) @@ -124,11 +134,20 @@ public List getAllTaskInfo(@Context UriInfo uriInfo) @Path("{taskId}") @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) - public Response createOrUpdateTask(@PathParam("taskId") TaskId taskId, TaskUpdateRequest taskUpdateRequest, @Context UriInfo uriInfo) + public void createOrUpdateTask( + @PathParam("taskId") TaskId taskId, + TaskUpdateRequest taskUpdateRequest, + @Context UriInfo uriInfo, + @Suspended AsyncResponse asyncResponse) { requireNonNull(taskUpdateRequest, "taskUpdateRequest is null"); Session session = taskUpdateRequest.getSession().toSession(sessionPropertyManager, taskUpdateRequest.getExtraCredentials()); + + if (injectFailure(session.getTraceToken(), taskId, RequestType.CREATE_OR_UPDATE_TASK, asyncResponse)) { + return; + } + TaskInfo taskInfo = taskManager.updateTask(session, taskId, taskUpdateRequest.getFragment(), @@ -140,7 +159,7 @@ public Response createOrUpdateTask(@PathParam("taskId") TaskId taskId, TaskUpdat taskInfo = taskInfo.summarize(); } - return Response.ok().entity(taskInfo).build(); + asyncResponse.resume(Response.ok().entity(taskInfo).build()); } @ResourceSecurity(INTERNAL_ONLY) @@ -156,6 +175,10 @@ public void getTaskInfo( { requireNonNull(taskId, "taskId is null"); + if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.GET_TASK_INFO, asyncResponse)) { + return; + } + if (currentVersion == null || maxWait == null) { TaskInfo taskInfo = taskManager.getTaskInfo(taskId); if (shouldSummarize(uriInfo)) { @@ -195,6 +218,10 @@ public void getTaskStatus( { requireNonNull(taskId, "taskId is null"); + if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.GET_TASK_STATUS, 
asyncResponse)) { + return; + } + if (currentVersion == null || maxWait == null) { TaskStatus taskStatus = taskManager.getTaskStatus(taskId); asyncResponse.resume(taskStatus); @@ -221,14 +248,20 @@ public void getTaskStatus( @GET @Path("{taskId}/dynamicfilters") @Produces(MediaType.APPLICATION_JSON) - public VersionedDynamicFilterDomains acknowledgeAndGetNewDynamicFilterDomains( + public void acknowledgeAndGetNewDynamicFilterDomains( @PathParam("taskId") TaskId taskId, @HeaderParam(TRINO_CURRENT_VERSION) Long currentDynamicFiltersVersion, - @Context UriInfo uriInfo) + @Context UriInfo uriInfo, + @Suspended AsyncResponse asyncResponse) { requireNonNull(taskId, "taskId is null"); requireNonNull(currentDynamicFiltersVersion, "currentDynamicFiltersVersion is null"); - return taskManager.acknowledgeAndGetNewDynamicFilterDomains(taskId, currentDynamicFiltersVersion); + + if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.ACKNOWLEDGE_AND_GET_NEW_DYNAMIC_FILTER_DOMAINS, asyncResponse)) { + return; + } + + asyncResponse.resume(taskManager.acknowledgeAndGetNewDynamicFilterDomains(taskId, currentDynamicFiltersVersion)); } @ResourceSecurity(INTERNAL_ONLY) @@ -270,6 +303,13 @@ public void getResults( requireNonNull(taskId, "taskId is null"); requireNonNull(bufferId, "bufferId is null"); + if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.GET_RESULTS, asyncResponse)) { + return; + } + + TaskState state = taskManager.getTaskStatus(taskId).getState(); + boolean taskFailed = state == TaskState.ABORTED || state == TaskState.FAILED; + long start = System.nanoTime(); ListenableFuture bufferResultFuture = taskManager.getTaskResults(taskId, bufferId, token, maxSize); Duration waitTime = randomizeWaitTime(DEFAULT_MAX_WAIT_TIME); @@ -298,6 +338,7 @@ public void getResults( .header(TRINO_PAGE_TOKEN, result.getToken()) .header(TRINO_PAGE_NEXT_TOKEN, result.getNextToken()) .header(TRINO_BUFFER_COMPLETE, result.isBufferComplete()) + 
.header(TRINO_TASK_FAILED, taskFailed) .build(); }, directExecutor()); @@ -310,6 +351,7 @@ public void getResults( .header(TRINO_PAGE_TOKEN, token) .header(TRINO_PAGE_NEXT_TOKEN, token) .header(TRINO_BUFFER_COMPLETE, false) + .header(TRINO_TASK_FAILED, taskFailed) .build()); responseFuture.addListener(() -> readFromOutputBufferTime.add(Duration.nanosSince(start)), directExecutor()); @@ -333,13 +375,105 @@ public void acknowledgeResults( @ResourceSecurity(INTERNAL_ONLY) @DELETE @Path("{taskId}/results/{bufferId}") - @Produces(MediaType.APPLICATION_JSON) - public void abortResults(@PathParam("taskId") TaskId taskId, @PathParam("bufferId") OutputBufferId bufferId, @Context UriInfo uriInfo) + public void abortResults( + @PathParam("taskId") TaskId taskId, + @PathParam("bufferId") OutputBufferId bufferId, + @Context UriInfo uriInfo, + @Suspended AsyncResponse asyncResponse) { requireNonNull(taskId, "taskId is null"); requireNonNull(bufferId, "bufferId is null"); + if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.ABORT_RESULTS, asyncResponse)) { + return; + } + taskManager.abortTaskResults(taskId, bufferId); + asyncResponse.resume(Response.noContent().build()); + } + + private boolean injectFailure( + Optional traceToken, + TaskId taskId, + RequestType requestType, + AsyncResponse asyncResponse) + { + if (traceToken.isEmpty()) { + return false; + } + + Optional injectedFailure = failureInjector.getInjectedFailure( + traceToken.get(), + taskId.getStageId().getId(), + taskId.getPartitionId(), + taskId.getAttemptId()); + + if (injectedFailure.isEmpty()) { + return false; + } + + InjectedFailure failure = injectedFailure.get(); + Duration timeout = failureInjector.getRequestTimeout(); + switch (failure.getInjectedFailureType()) { + case TASK_MANAGEMENT_REQUEST_FAILURE: + if (requestType.isTaskManagement()) { + log.info("Failing %s request for task %s", requestType, taskId); + asyncResponse.resume(Response.serverError().build()); + return true; + } + 
break; + case TASK_MANAGEMENT_REQUEST_TIMEOUT: + if (requestType.isTaskManagement()) { + log.info("Timing out %s request for task %s", requestType, taskId); + asyncResponse.setTimeout(timeout.toMillis(), MILLISECONDS); + return true; + } + break; + case TASK_GET_RESULTS_REQUEST_FAILURE: + if (!requestType.isTaskManagement()) { + log.info("Failing %s request for task %s", requestType, taskId); + asyncResponse.resume(Response.serverError().build()); + return true; + } + break; + case TASK_GET_RESULTS_REQUEST_TIMEOUT: + if (!requestType.isTaskManagement()) { + log.info("Timing out %s request for task %s", requestType, taskId); + asyncResponse.setTimeout(timeout.toMillis(), MILLISECONDS); + return true; + } + break; + case TASK_FAILURE: + log.info("Injecting failure for task %s at %s", taskId, requestType); + taskManager.failTask(taskId, injectedFailure.get().getTaskFailureException()); + break; + default: + throw new IllegalArgumentException("unexpected failure type: " + failure.getInjectedFailureType()); + } + + return false; + } + + private enum RequestType + { + CREATE_OR_UPDATE_TASK(true), + GET_TASK_INFO(true), + GET_TASK_STATUS(true), + ACKNOWLEDGE_AND_GET_NEW_DYNAMIC_FILTER_DOMAINS(true), + GET_RESULTS(false), + ABORT_RESULTS(false); + + private final boolean taskManagement; + + RequestType(boolean taskManagement) + { + this.taskManagement = taskManagement; + } + + public boolean isTaskManagement() + { + return taskManagement; + } } @Managed diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java b/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java index a33a212ca529..7063413c618b 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java @@ -24,8 +24,6 @@ import io.trino.client.ProtocolHeaders; import io.trino.client.QueryResults; import 
io.trino.execution.QueryManager; -import io.trino.memory.context.SimpleLocalMemoryContext; -import io.trino.operator.ExchangeClient; import io.trino.operator.ExchangeClientSupplier; import io.trino.server.ForStatementResource; import io.trino.server.ServerConfig; @@ -61,7 +59,6 @@ import static io.airlift.concurrent.Threads.threadsNamed; import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse; import static io.airlift.units.DataSize.Unit.MEGABYTE; -import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.server.protocol.Slug.Context.EXECUTING_QUERY; import static io.trino.server.security.ResourceSecurity.AccessType.PUBLIC; import static java.nio.charset.StandardCharsets.UTF_8; @@ -181,18 +178,15 @@ protected Query getQuery(QueryId queryId, String slug, long token) throw queryNotFound(); } - query = queries.computeIfAbsent(queryId, id -> { - ExchangeClient exchangeClient = exchangeClientSupplier.get(new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), ExecutingStatementResource.class.getSimpleName())); - return Query.create( - session, - querySlug, - queryManager, - queryInfoUrlFactory.getQueryInfoUrl(queryId), - exchangeClient, - responseExecutor, - timeoutExecutor, - blockEncodingSerde); - }); + query = queries.computeIfAbsent(queryId, id -> Query.create( + session, + querySlug, + queryManager, + queryInfoUrlFactory.getQueryInfoUrl(queryId), + exchangeClientSupplier, + responseExecutor, + timeoutExecutor, + blockEncodingSerde)); return query; } diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java index d3251d23389f..9e6c53d70a96 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java @@ -48,7 +48,9 @@ import io.trino.execution.buffer.PagesSerde; import 
io.trino.execution.buffer.PagesSerdeFactory; import io.trino.execution.buffer.SerializedPage; +import io.trino.memory.context.SimpleLocalMemoryContext; import io.trino.operator.ExchangeClient; +import io.trino.operator.ExchangeClientSupplier; import io.trino.spi.ErrorCode; import io.trino.spi.Page; import io.trino.spi.QueryId; @@ -97,8 +99,10 @@ import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.addTimeout; +import static io.trino.SystemSessionProperties.getRetryPolicy; import static io.trino.SystemSessionProperties.isExchangeCompressionEnabled; import static io.trino.execution.QueryState.FAILED; +import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.server.protocol.QueryInfoUrlFactory.getQueryInfoUri; import static io.trino.server.protocol.QueryResultRows.queryResultRowsBuilder; import static io.trino.server.protocol.Slug.Context.EXECUTING_QUERY; @@ -190,11 +194,16 @@ public static Query create( Slug slug, QueryManager queryManager, Optional queryInfoUrl, - ExchangeClient exchangeClient, + ExchangeClientSupplier exchangeClientSupplier, Executor dataProcessorExecutor, ScheduledExecutorService timeoutExecutor, BlockEncodingSerde blockEncodingSerde) { + ExchangeClient exchangeClient = exchangeClientSupplier.get( + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), Query.class.getSimpleName()), + queryManager::outputTaskFailed, + getRetryPolicy(session)); + Query result = new Query(session, slug, queryManager, queryInfoUrl, exchangeClient, dataProcessorExecutor, timeoutExecutor, blockEncodingSerde); result.queryManager.addOutputInfoListener(result.getQueryId(), result::setQueryOutputInfo); @@ -354,7 +363,7 @@ public synchronized ListenableFuture waitForResults(long token, Ur private synchronized ListenableFuture getFutureStateChange() { // 
if the exchange client is open, wait for data - if (!exchangeClient.isClosed()) { + if (!exchangeClient.isFinished()) { return exchangeClient.isBlocked(); } @@ -433,7 +442,7 @@ private synchronized QueryResults getNextResult(long token, UriInfo uriInfo, Dat // (1) the query is not done AND the query state is not FAILED // OR // (2)there is more data to send (due to buffering) - if (queryInfo.getState() != FAILED && (!queryInfo.isFinalQueryInfo() || !exchangeClient.isClosed() || (lastResult != null && lastResult.getData() != null))) { + if (queryInfo.getState() != FAILED && (!queryInfo.isFinalQueryInfo() || !exchangeClient.isFinished() || (lastResult != null && lastResult.getData() != null))) { nextToken = OptionalLong.of(token + 1); } else { @@ -573,9 +582,7 @@ private synchronized void setQueryOutputInfo(QueryExecution.QueryOutputInfo outp types = outputInfo.getColumnTypes(); } - for (URI outputLocation : outputInfo.getBufferLocations()) { - exchangeClient.addLocation(outputLocation); - } + outputInfo.getBufferLocations().forEach(exchangeClient::addLocation); if (outputInfo.isNoMoreBufferLocations()) { exchangeClient.noMoreLocations(); } @@ -772,6 +779,8 @@ private static StageStats toStageStats(StageInfo stageInfo) .setProcessedRows(stageStats.getRawInputPositions()) .setProcessedBytes(stageStats.getRawInputDataSize().toBytes()) .setPhysicalInputBytes(stageStats.getPhysicalInputDataSize().toBytes()) + .setFailedTasks(stageStats.getFailedTasks()) + .setCoordinatorOnly(stageInfo.isCoordinatorOnly()) .setSubStages(subStages.build()) .build(); } diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java index 7b28f4bab16d..fedf562dce19 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java @@ -255,7 +255,7 @@ public HttpRemoteTask( TaskInfo 
initialTask = createInitialTask(taskId, location, nodeId, bufferStates, new TaskStats(DateTime.now(), null)); this.dynamicFiltersFetcher = new DynamicFiltersFetcher( - this::failTask, + this::fail, taskId, location, taskStatusRefreshMaxWait, @@ -268,7 +268,7 @@ public HttpRemoteTask( dynamicFilterService); this.taskStatusFetcher = new ContinuousTaskStatusFetcher( - this::failTask, + this::fail, initialTask.getTaskStatus(), taskStatusRefreshMaxWait, taskStatusCodec, @@ -280,7 +280,7 @@ public HttpRemoteTask( stats); this.taskInfoFetcher = new TaskInfoFetcher( - this::failTask, + this::fail, taskStatusFetcher, initialTask, httpClient, @@ -795,7 +795,8 @@ public void onFailure(Throwable t) /** * Move the task directly to the failed state if there was a failure in this task */ - private void failTask(Throwable cause) + @Override + public synchronized void fail(Throwable cause) { TaskStatus taskStatus = getTaskStatus(); if (!taskStatus.getState().isDone()) { @@ -913,11 +914,11 @@ public void failed(Throwable cause) } } catch (Error e) { - failTask(e); + fail(e); throw e; } catch (RuntimeException e) { - failTask(e); + fail(e); } finally { sendUpdate(); @@ -929,7 +930,7 @@ public void failed(Throwable cause) public void fatal(Throwable cause) { try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) { - failTask(cause); + fail(cause); } } diff --git a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java index 654c34e450db..75e6778eeed2 100644 --- a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java +++ b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java @@ -44,6 +44,8 @@ import io.trino.dispatcher.DispatchManager; import io.trino.eventlistener.EventListenerConfig; import io.trino.eventlistener.EventListenerManager; +import io.trino.execution.FailureInjector; +import 
io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.execution.QueryInfo; import io.trino.execution.QueryManager; import io.trino.execution.SqlQueryManager; @@ -68,6 +70,7 @@ import io.trino.server.ShutdownAction; import io.trino.server.security.CertificateAuthenticatorManager; import io.trino.server.security.ServerSecurityModule; +import io.trino.spi.ErrorType; import io.trino.spi.Plugin; import io.trino.spi.QueryId; import io.trino.spi.eventlistener.EventListener; @@ -163,6 +166,7 @@ public static Builder builder() private final ShutdownAction shutdownAction; private final MBeanServer mBeanServer; private final boolean coordinator; + private final FailureInjector failureInjector; public static class TestShutdownAction implements ShutdownAction @@ -323,6 +327,7 @@ private TestingTrinoServer( shutdownAction = injector.getInstance(ShutdownAction.class); mBeanServer = injector.getInstance(MBeanServer.class); announcer = injector.getInstance(Announcer.class); + failureInjector = injector.getInstance(FailureInjector.class); accessControl.setSystemAccessControls(systemAccessControls); @@ -557,6 +562,23 @@ public T getInstance(Key key) return injector.getInstance(key); } + public void injectTaskFailure( + String traceToken, + int stageId, + int partitionId, + int attemptId, + InjectedFailureType injectionType, + Optional errorType) + { + failureInjector.injectTaskFailure( + traceToken, + stageId, + partitionId, + attemptId, + injectionType, + errorType); + } + private static void updateConnectorIdAnnouncement(Announcer announcer, CatalogName catalogName, InternalNodeManager nodeManager) { // diff --git a/core/trino-main/src/main/java/io/trino/split/RemoteSplit.java b/core/trino-main/src/main/java/io/trino/split/RemoteSplit.java index 1a652641d639..e043fd1a2b22 100644 --- a/core/trino-main/src/main/java/io/trino/split/RemoteSplit.java +++ b/core/trino-main/src/main/java/io/trino/split/RemoteSplit.java @@ -16,6 +16,7 @@ import 
com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; +import io.trino.execution.TaskId; import io.trino.spi.HostAddress; import io.trino.spi.connector.ConnectorSplit; @@ -28,14 +29,22 @@ public class RemoteSplit implements ConnectorSplit { + private final TaskId taskId; private final URI location; @JsonCreator - public RemoteSplit(@JsonProperty("location") URI location) + public RemoteSplit(@JsonProperty("taskId") TaskId taskId, @JsonProperty("location") URI location) { + this.taskId = requireNonNull(taskId, "taskId is null"); this.location = requireNonNull(location, "location is null"); } + @JsonProperty + public TaskId getTaskId() + { + return taskId; + } + @JsonProperty public URI getLocation() { @@ -64,6 +73,7 @@ public List getAddresses() public String toString() { return toStringHelper(this) + .add("taskId", taskId) .add("location", location) .toString(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index d12908a34bec..03cbf841a3d3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -72,6 +72,7 @@ import io.trino.operator.PartitionFunction; import io.trino.operator.PipelineExecutionStrategy; import io.trino.operator.RefreshMaterializedViewOperator.RefreshMaterializedViewOperatorFactory; +import io.trino.operator.RetryPolicy; import io.trino.operator.RowNumberOperator; import io.trino.operator.ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory; import io.trino.operator.SetBuilderOperator.SetBuilderOperatorFactory; @@ -859,6 +860,7 @@ public PhysicalOperation visitRemoteSource(RemoteSourceNode node, LocalExecution private PhysicalOperation createMergeSource(RemoteSourceNode node, LocalExecutionPlanContext 
context) { checkArgument(node.getOrderingScheme().isPresent(), "orderingScheme is absent"); + checkArgument(node.getRetryPolicy() == RetryPolicy.NONE, "unexpected retry policy: " + node.getRetryPolicy()); // merging remote source must have a single driver context.setDriverInstanceCount(1); @@ -897,7 +899,8 @@ private PhysicalOperation createRemoteSource(RemoteSourceNode node, LocalExecuti context.getNextOperatorId(), node.getId(), exchangeClientSupplier, - new PagesSerdeFactory(metadata.getBlockEncodingSerde(), isExchangeCompressionEnabled(session))); + new PagesSerdeFactory(metadata.getBlockEncodingSerde(), isExchangeCompressionEnabled(session)), + node.getRetryPolicy()); return new PhysicalOperation(operatorFactory, makeLayout(node), context, UNGROUPED_EXECUTION); } @@ -3068,7 +3071,7 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont private Set getCoordinatorDynamicFilters(Set dynamicFilters, PlanNode node, TaskId taskId) { - if (!isBuildSideReplicated(node) || taskId.getId() == 0) { + if (!isBuildSideReplicated(node) || taskId.getPartitionId() == 0) { // replicated dynamic filters are collected by single stage task only return dynamicFilters; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java index 61faf185924c..dce85da42cca 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java @@ -25,6 +25,7 @@ import io.trino.metadata.Metadata; import io.trino.metadata.TableHandle; import io.trino.metadata.TableProperties.TablePartitioning; +import io.trino.operator.RetryPolicy; import io.trino.spi.TrinoException; import io.trino.spi.TrinoWarning; import io.trino.spi.connector.ConnectorPartitionHandle; @@ -68,6 +69,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static 
com.google.common.collect.Iterables.getOnlyElement; import static io.trino.SystemSessionProperties.getQueryMaxStageCount; +import static io.trino.SystemSessionProperties.getRetryPolicy; import static io.trino.SystemSessionProperties.isDynamicScheduleForGroupedExecution; import static io.trino.SystemSessionProperties.isForceSingleNodeOutput; import static io.trino.operator.StageExecutionDescriptor.ungroupedExecution; @@ -356,13 +358,15 @@ else if (exchange.getType() == ExchangeNode.Type.REPARTITION) { context.get().setDistribution(partitioningScheme.getPartitioning().getHandle(), metadata, session); } - ImmutableList.Builder builder = ImmutableList.builder(); + ImmutableList.Builder childrenProperties = ImmutableList.builder(); + ImmutableList.Builder childrenBuilder = ImmutableList.builder(); for (int sourceIndex = 0; sourceIndex < exchange.getSources().size(); sourceIndex++) { FragmentProperties childProperties = new FragmentProperties(partitioningScheme.translateOutputLayout(exchange.getInputs().get(sourceIndex))); - builder.add(buildSubPlan(exchange.getSources().get(sourceIndex), childProperties, context)); + childrenProperties.add(childProperties); + childrenBuilder.add(buildSubPlan(exchange.getSources().get(sourceIndex), childProperties, context)); } - List children = builder.build(); + List children = childrenBuilder.build(); context.get().addChildren(children); List childrenIds = children.stream() @@ -370,7 +374,13 @@ else if (exchange.getType() == ExchangeNode.Type.REPARTITION) { .map(PlanFragment::getId) .collect(toImmutableList()); - return new RemoteSourceNode(exchange.getId(), childrenIds, exchange.getOutputSymbols(), exchange.getOrderingScheme(), exchange.getType()); + return new RemoteSourceNode( + exchange.getId(), + childrenIds, + exchange.getOutputSymbols(), + exchange.getOrderingScheme(), + exchange.getType(), + isWorkerCoordinatorBoundary(context.get(), childrenProperties.build()) ? 
getRetryPolicy(session) : RetryPolicy.NONE); } private SubPlan buildSubPlan(PlanNode node, FragmentProperties properties, RewriteContext context) @@ -379,6 +389,22 @@ private SubPlan buildSubPlan(PlanNode node, FragmentProperties properties, Rewri PlanNode child = context.rewrite(node, properties); return buildFragment(child, properties, planFragmentId); } + + private static boolean isWorkerCoordinatorBoundary(FragmentProperties fragmentProperties, List childFragmentsProperties) + { + if (!fragmentProperties.getPartitioningHandle().isCoordinatorOnly()) { + // receiver stage is not a coordinator stage + return false; + } + if (childFragmentsProperties.stream().allMatch(properties -> properties.getPartitioningHandle().isCoordinatorOnly())) { + // coordinator to coordinator exchange + return false; + } + checkArgument( + childFragmentsProperties.stream().noneMatch(properties -> properties.getPartitioningHandle().isCoordinatorOnly()), + "Plans are not expected to have a mix of coordinator only fragments and distributed fragments as siblings"); + return true; + } } private static class FragmentProperties diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java similarity index 87% rename from core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java rename to core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java index fb6b13d4fe4b..f1c70db748c0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java @@ -17,10 +17,7 @@ import com.google.common.collect.ImmutableMap; import io.airlift.log.Logger; import io.trino.Session; -import io.trino.execution.TableInfo; import io.trino.metadata.Metadata; -import io.trino.metadata.TableProperties; -import io.trino.metadata.TableSchema; import 
io.trino.operator.StageExecutionDescriptor; import io.trino.server.DynamicFilterService; import io.trino.spi.connector.Constraint; @@ -77,42 +74,43 @@ import java.util.Map; import java.util.Optional; -import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.GROUPED_SCHEDULING; import static io.trino.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.UNGROUPED_SCHEDULING; import static io.trino.spi.connector.Constraint.alwaysTrue; import static io.trino.spi.connector.DynamicFilter.EMPTY; import static io.trino.sql.ExpressionUtils.filterConjuncts; -import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static java.util.Objects.requireNonNull; -public class DistributedExecutionPlanner +public class SplitSourceFactory { - private static final Logger log = Logger.get(DistributedExecutionPlanner.class); + private static final Logger log = Logger.get(SplitSourceFactory.class); private final SplitManager splitManager; - private final Metadata metadata; private final DynamicFilterService dynamicFilterService; + private final Metadata metadata; private final TypeAnalyzer typeAnalyzer; @Inject - public DistributedExecutionPlanner(SplitManager splitManager, Metadata metadata, DynamicFilterService dynamicFilterService, TypeAnalyzer typeAnalyzer) + public SplitSourceFactory(SplitManager splitManager, DynamicFilterService dynamicFilterService, Metadata metadata, TypeAnalyzer typeAnalyzer) { this.splitManager = requireNonNull(splitManager, "splitManager is null"); - this.metadata = requireNonNull(metadata, "metadata is null"); this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } - public 
StageExecutionPlan plan(SubPlan root, Session session) + public Map createSplitSources(Session session, PlanFragment fragment) { ImmutableList.Builder allSplitSources = ImmutableList.builder(); try { - return doPlan(root, session, allSplitSources); + // get splits for this fragment, this is lazy so split assignments aren't actually calculated here + return fragment.getRoot().accept( + new Visitor(session, fragment.getStageExecutionDescriptor(), TypeProvider.copyOf(fragment.getSymbols()), allSplitSources), + null); } catch (Throwable t) { - allSplitSources.build().forEach(DistributedExecutionPlanner::closeSplitSource); + allSplitSources.build().forEach(SplitSourceFactory::closeSplitSource); throw t; } } @@ -127,43 +125,6 @@ private static void closeSplitSource(SplitSource source) } } - private StageExecutionPlan doPlan(SubPlan root, Session session, ImmutableList.Builder allSplitSources) - { - PlanFragment currentFragment = root.getFragment(); - - // get splits for this fragment, this is lazy so split assignments aren't actually calculated here - Map splitSources = currentFragment.getRoot().accept( - new Visitor(session, currentFragment.getStageExecutionDescriptor(), TypeProvider.copyOf(currentFragment.getSymbols()), allSplitSources), - null); - - // create child stages - ImmutableList.Builder dependencies = ImmutableList.builder(); - for (SubPlan childPlan : root.getChildren()) { - dependencies.add(doPlan(childPlan, session, allSplitSources)); - } - - // extract TableInfo - Map tables = searchFrom(root.getFragment().getRoot()) - .where(TableScanNode.class::isInstance) - .findAll() - .stream() - .map(TableScanNode.class::cast) - .collect(toImmutableMap(PlanNode::getId, node -> getTableInfo(node, session))); - - return new StageExecutionPlan( - currentFragment, - splitSources, - dependencies.build(), - tables); - } - - private TableInfo getTableInfo(TableScanNode node, Session session) - { - TableSchema tableSchema = metadata.getTableSchema(session, node.getTable()); 
- TableProperties tableProperties = metadata.getTableProperties(session, node.getTable()); - return new TableInfo(tableSchema.getQualifiedName(), tableProperties.getPredicate()); - } - private final class Visitor extends PlanVisitor, Void> { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/StageExecutionPlan.java b/core/trino-main/src/main/java/io/trino/sql/planner/StageExecutionPlan.java deleted file mode 100644 index 3748783a109c..000000000000 --- a/core/trino-main/src/main/java/io/trino/sql/planner/StageExecutionPlan.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.sql.planner; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import io.trino.execution.TableInfo; -import io.trino.split.SplitSource; -import io.trino.sql.planner.plan.OutputNode; -import io.trino.sql.planner.plan.PlanNodeId; - -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.base.Preconditions.checkState; -import static java.util.Objects.requireNonNull; - -public class StageExecutionPlan -{ - private final PlanFragment fragment; - private final Map splitSources; - private final List subStages; - private final Optional> fieldNames; - private final Map tables; - - public StageExecutionPlan( - PlanFragment fragment, - Map splitSources, - List subStages, Map tables) - { - this.fragment = requireNonNull(fragment, "fragment is null"); - this.splitSources = requireNonNull(splitSources, "splitSources is null"); - this.subStages = ImmutableList.copyOf(requireNonNull(subStages, "subStages is null")); - - fieldNames = (fragment.getRoot() instanceof OutputNode) ? 
- Optional.of(ImmutableList.copyOf(((OutputNode) fragment.getRoot()).getColumnNames())) : - Optional.empty(); - - this.tables = ImmutableMap.copyOf(requireNonNull(tables, "tables is null")); - } - - public List getFieldNames() - { - checkState(fieldNames.isPresent(), "cannot get field names from non-output stage"); - return fieldNames.get(); - } - - public PlanFragment getFragment() - { - return fragment; - } - - public Map getSplitSources() - { - return splitSources; - } - - public List getSubStages() - { - return subStages; - } - - public Map getTables() - { - return tables; - } - - public StageExecutionPlan withBucketToPartition(Optional bucketToPartition) - { - return new StageExecutionPlan(fragment.withBucketToPartition(bucketToPartition), splitSources, subStages, tables); - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("fragment", fragment) - .add("splitSources", splitSources) - .add("subStages", subStages) - .toString(); - } -} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index b3345337c0b0..0710270a7c83 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -464,7 +464,8 @@ public PlanAndMappings visitRemoteSource(RemoteSourceNode node, UnaliasContext c node.getSourceFragmentIds(), newOutputs, newOrderingScheme, - node.getExchangeType()), + node.getExchangeType(), + node.getRetryPolicy()), mapping); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/RemoteSourceNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/RemoteSourceNode.java index 17dd62267cb1..46cf21dff1a4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/RemoteSourceNode.java +++ 
b/core/trino-main/src/main/java/io/trino/sql/planner/plan/RemoteSourceNode.java @@ -16,6 +16,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; +import io.trino.operator.RetryPolicy; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; @@ -35,6 +36,7 @@ public class RemoteSourceNode private final List outputs; private final Optional orderingScheme; private final ExchangeNode.Type exchangeType; // This is needed to "unfragment" to compute stats correctly. + private final RetryPolicy retryPolicy; @JsonCreator public RemoteSourceNode( @@ -42,7 +44,8 @@ public RemoteSourceNode( @JsonProperty("sourceFragmentIds") List sourceFragmentIds, @JsonProperty("outputs") List outputs, @JsonProperty("orderingScheme") Optional orderingScheme, - @JsonProperty("exchangeType") ExchangeNode.Type exchangeType) + @JsonProperty("exchangeType") ExchangeNode.Type exchangeType, + @JsonProperty("retryPolicy") RetryPolicy retryPolicy) { super(id); @@ -52,11 +55,18 @@ public RemoteSourceNode( this.outputs = ImmutableList.copyOf(outputs); this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); this.exchangeType = requireNonNull(exchangeType, "exchangeType is null"); + this.retryPolicy = requireNonNull(retryPolicy, "retryPolicy is null"); } - public RemoteSourceNode(PlanNodeId id, PlanFragmentId sourceFragmentId, List outputs, Optional orderingScheme, ExchangeNode.Type exchangeType) + public RemoteSourceNode( + PlanNodeId id, + PlanFragmentId sourceFragmentId, + List outputs, + Optional orderingScheme, + ExchangeNode.Type exchangeType, + RetryPolicy retryPolicy) { - this(id, ImmutableList.of(sourceFragmentId), outputs, orderingScheme, exchangeType); + this(id, ImmutableList.of(sourceFragmentId), outputs, orderingScheme, exchangeType, retryPolicy); } @Override @@ -90,6 +100,12 @@ public ExchangeNode.Type getExchangeType() return 
exchangeType; } + @JsonProperty("retryPolicy") + public RetryPolicy getRetryPolicy() + { + return retryPolicy; + } + @Override public R accept(PlanVisitor visitor, C context) { diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index 2fe35949fc14..e4756422f39a 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -51,6 +51,7 @@ import io.trino.eventlistener.EventListenerConfig; import io.trino.eventlistener.EventListenerManager; import io.trino.execution.DynamicFilterConfig; +import io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.execution.Lifespan; import io.trino.execution.NodeTaskMap; import io.trino.execution.QueryManagerConfig; @@ -109,6 +110,7 @@ import io.trino.server.security.HeaderAuthenticatorManager; import io.trino.server.security.PasswordAuthenticatorConfig; import io.trino.server.security.PasswordAuthenticatorManager; +import io.trino.spi.ErrorType; import io.trino.spi.PageIndexerFactory; import io.trino.spi.PageSorter; import io.trino.spi.Plugin; @@ -733,6 +735,18 @@ public Lock getExclusiveLock() return lock.writeLock(); } + @Override + public void injectTaskFailure( + String traceToken, + int stageId, + int partitionId, + int attemptId, + InjectedFailureType injectionType, + Optional errorType) + { + throw new UnsupportedOperationException("failure injection is not supported"); + } + public List createDrivers(@Language("SQL") String sql, OutputFactory outputFactory, TaskContext taskContext) { return createDrivers(defaultSession, sql, outputFactory, taskContext); diff --git a/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java b/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java index ec9f22585964..4fd7474add90 100644 --- 
a/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java +++ b/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slices; import io.trino.Session; +import io.trino.client.StatementStats; import io.trino.client.Warning; import io.trino.spi.Page; import io.trino.spi.PageBuilder; @@ -95,10 +96,11 @@ public class MaterializedResult private final Optional updateType; private final OptionalLong updateCount; private final List warnings; + private final Optional statementStats; public MaterializedResult(List rows, List types) { - this(rows, types, ImmutableMap.of(), ImmutableSet.of(), Optional.empty(), OptionalLong.empty(), ImmutableList.of()); + this(rows, types, ImmutableMap.of(), ImmutableSet.of(), Optional.empty(), OptionalLong.empty(), ImmutableList.of(), Optional.empty()); } public MaterializedResult( @@ -108,7 +110,8 @@ public MaterializedResult( Set resetSessionProperties, Optional updateType, OptionalLong updateCount, - List warnings) + List warnings, + Optional statementStats) { this.rows = ImmutableList.copyOf(requireNonNull(rows, "rows is null")); this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); @@ -117,6 +120,7 @@ public MaterializedResult( this.updateType = requireNonNull(updateType, "updateType is null"); this.updateCount = requireNonNull(updateCount, "updateCount is null"); this.warnings = requireNonNull(warnings, "warnings is null"); + this.statementStats = requireNonNull(statementStats, "statementStats is null"); } public int getRowCount() @@ -165,6 +169,11 @@ public List getWarnings() return warnings; } + public Optional getStatementStats() + { + return statementStats; + } + @Override public boolean equals(Object obj) { @@ -357,7 +366,8 @@ public MaterializedResult toTestTypes() resetSessionProperties, updateType, updateCount, - warnings); + warnings, + statementStats); } private static MaterializedRow 
convertToTestTypes(MaterializedRow trinoRow) diff --git a/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java index 1a04088e729d..bcfcda7ccf23 100644 --- a/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java @@ -15,10 +15,12 @@ import io.trino.Session; import io.trino.cost.StatsCalculator; +import io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.SqlFunction; +import io.trino.spi.ErrorType; import io.trino.spi.Plugin; import io.trino.split.PageSourceManager; import io.trino.split.SplitManager; @@ -31,6 +33,7 @@ import java.io.Closeable; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.locks.Lock; public interface QueryRunner @@ -87,6 +90,14 @@ default Plan createPlan(Session session, @Language("SQL") String sql, WarningCol Lock getExclusiveLock(); + void injectTaskFailure( + String traceToken, + int stageId, + int partitionId, + int attemptId, + InjectedFailureType injectionType, + Optional errorType); + class MaterializedResultWithPlan { private final MaterializedResult materializedResult; diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingTaskContext.java b/core/trino-main/src/main/java/io/trino/testing/TestingTaskContext.java index 5c6d7002e90d..64366c52f1e9 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingTaskContext.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingTaskContext.java @@ -17,6 +17,7 @@ import io.airlift.stats.TestingGcMonitor; import io.airlift.units.DataSize; import io.trino.Session; +import io.trino.execution.StageId; import io.trino.execution.TaskId; import io.trino.execution.TaskStateMachine; import 
io.trino.memory.MemoryPool; @@ -60,7 +61,7 @@ public static TaskContext createTaskContext(Executor notificationExecutor, Sched public static TaskContext createTaskContext(QueryContext queryContext, Executor executor, Session session) { - return createTaskContext(queryContext, session, new TaskStateMachine(new TaskId("query", 0, 0), executor)); + return createTaskContext(queryContext, session, new TaskStateMachine(new TaskId(new StageId("query", 0), 0, 0), executor)); } private static TaskContext createTaskContext(QueryContext queryContext, Session session, TaskStateMachine taskStateMachine) @@ -96,7 +97,7 @@ private Builder(Executor notificationExecutor, ScheduledExecutorService yieldExe this.notificationExecutor = notificationExecutor; this.yieldExecutor = yieldExecutor; this.session = session; - this.taskStateMachine = new TaskStateMachine(new TaskId("query", 0, 0), notificationExecutor); + this.taskStateMachine = new TaskStateMachine(new TaskId(new StageId("query", 0), 0, 0), notificationExecutor); } public Builder setTaskStateMachine(TaskStateMachine taskStateMachine) diff --git a/core/trino-main/src/test/java/io/trino/execution/BenchmarkNodeScheduler.java b/core/trino-main/src/test/java/io/trino/execution/BenchmarkNodeScheduler.java index 44f29cdf6802..b89252294f3c 100644 --- a/core/trino-main/src/test/java/io/trino/execution/BenchmarkNodeScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/BenchmarkNodeScheduler.java @@ -156,7 +156,7 @@ public void setup() for (int j = 0; j < MAX_SPLITS_PER_NODE + MAX_PENDING_SPLITS_PER_TASK_PER_NODE; j++) { initialSplits.add(new Split(CONNECTOR_ID, new TestSplitRemote(i), Lifespan.taskWide())); } - TaskId taskId = new TaskId("test", 1, i); + TaskId taskId = new TaskId(new StageId("test", 1), i, 0); MockRemoteTaskFactory.MockRemoteTask remoteTask = remoteTaskFactory.createTableScanTask(taskId, node, initialSplits.build(), nodeTaskMap.createPartitionedSplitCountTracker(node, taskId)); nodeTaskMap.addTask(node, 
remoteTask); taskMap.put(node, remoteTask); diff --git a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java index 8d1adc1e2f26..8dd5e83d03e5 100644 --- a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java @@ -448,6 +448,13 @@ public void abort() clearSplits(); } + @Override + public void fail(Throwable cause) + { + taskStateMachine.failed(cause); + clearSplits(); + } + @Override public PartitionedSplitsInfo getPartitionedSplitsInfo() { diff --git a/core/trino-main/src/test/java/io/trino/execution/TestFailureInjectionConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestFailureInjectionConfig.java new file mode 100644 index 000000000000..eb89e03e3d4c --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/TestFailureInjectionConfig.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution; + +import com.google.common.collect.ImmutableMap; +import io.airlift.units.Duration; +import org.testng.annotations.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static java.util.concurrent.TimeUnit.MINUTES; + +public class TestFailureInjectionConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(FailureInjectionConfig.class) + .setRequestTimeout(new Duration(2, MINUTES)) + .setExpirationPeriod(new Duration(10, MINUTES))); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("failure-injection.request-timeout", "12m") + .put("failure-injection.expiration-period", "7m") + .build(); + + FailureInjectionConfig expected = new FailureInjectionConfig() + .setRequestTimeout(new Duration(12, MINUTES)) + .setExpirationPeriod(new Duration(7, MINUTES)); + + assertFullMapping(properties, expected); + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java b/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java index 95a08294b395..14eebbb832ea 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java @@ -293,7 +293,7 @@ private SqlTask newSqlTask(QueryId queryId) { QueryContext queryContext = getOrCreateQueryContext(queryId); - TaskId taskId = new TaskId(queryId.getId(), 0, idGeneator.incrementAndGet()); + TaskId taskId = new TaskId(new StageId(queryId.getId(), 0), idGeneator.incrementAndGet(), 0); URI location = URI.create("fake://task/" + taskId); return createSqlTask( diff --git 
a/core/trino-main/src/test/java/io/trino/execution/TestNodeScheduler.java b/core/trino-main/src/test/java/io/trino/execution/TestNodeScheduler.java index 5ba5a9f0e3c0..0eb58e5dd49d 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestNodeScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestNodeScheduler.java @@ -205,7 +205,7 @@ public void testTopologyAwareScheduling() MockRemoteTaskFactory remoteTaskFactory = new MockRemoteTaskFactory(remoteTaskExecutor, remoteTaskScheduledExecutor); int task = 0; for (InternalNode node : assignments.keySet()) { - TaskId taskId = new TaskId("test", 1, task); + TaskId taskId = new TaskId(new StageId("test", 1), task, 0); task++; MockRemoteTaskFactory.MockRemoteTask remoteTask = remoteTaskFactory.createTableScanTask(taskId, node, ImmutableList.copyOf(assignments.get(node)), nodeTaskMap.createPartitionedSplitCountTracker(node, taskId)); remoteTask.startSplits(25); @@ -322,11 +322,11 @@ public void testMaxSplitsPerNode() MockRemoteTaskFactory remoteTaskFactory = new MockRemoteTaskFactory(remoteTaskExecutor, remoteTaskScheduledExecutor); // Max out number of splits on node - TaskId taskId1 = new TaskId("test", 1, 1); + TaskId taskId1 = new TaskId(new StageId("test", 1), 1, 0); RemoteTask remoteTask1 = remoteTaskFactory.createTableScanTask(taskId1, newNode, initialSplits.build(), nodeTaskMap.createPartitionedSplitCountTracker(newNode, taskId1)); nodeTaskMap.addTask(newNode, remoteTask1); - TaskId taskId2 = new TaskId("test", 1, 2); + TaskId taskId2 = new TaskId(new StageId("test", 1), 2, 0); RemoteTask remoteTask2 = remoteTaskFactory.createTableScanTask(taskId2, newNode, initialSplits.build(), nodeTaskMap.createPartitionedSplitCountTracker(newNode, taskId2)); nodeTaskMap.addTask(newNode, remoteTask2); @@ -381,13 +381,13 @@ public void testMaxSplitsPerNodePerTask() MockRemoteTaskFactory remoteTaskFactory = new MockRemoteTaskFactory(remoteTaskExecutor, remoteTaskScheduledExecutor); for (InternalNode node 
: nodeManager.getActiveConnectorNodes(CONNECTOR_ID)) { // Max out number of splits on node - TaskId taskId = new TaskId("test", 1, 1); + TaskId taskId = new TaskId(new StageId("test", 1), 1, 0); RemoteTask remoteTask = remoteTaskFactory.createTableScanTask(taskId, node, initialSplits.build(), nodeTaskMap.createPartitionedSplitCountTracker(node, taskId)); nodeTaskMap.addTask(node, remoteTask); tasks.add(remoteTask); } - TaskId taskId = new TaskId("test", 1, 2); + TaskId taskId = new TaskId(new StageId("test", 1), 2, 0); RemoteTask newRemoteTask = remoteTaskFactory.createTableScanTask(taskId, newNode, initialSplits.build(), nodeTaskMap.createPartitionedSplitCountTracker(newNode, taskId)); // Max out pending splits on new node taskMap.put(newNode, newRemoteTask); @@ -418,7 +418,7 @@ public void testTaskCompletion() setUpNodes(); MockRemoteTaskFactory remoteTaskFactory = new MockRemoteTaskFactory(remoteTaskExecutor, remoteTaskScheduledExecutor); InternalNode chosenNode = Iterables.get(nodeManager.getActiveConnectorNodes(CONNECTOR_ID), 0); - TaskId taskId = new TaskId("test", 1, 1); + TaskId taskId = new TaskId(new StageId("test", 1), 1, 0); RemoteTask remoteTask = remoteTaskFactory.createTableScanTask( taskId, chosenNode, @@ -441,7 +441,7 @@ public void testSplitCount() MockRemoteTaskFactory remoteTaskFactory = new MockRemoteTaskFactory(remoteTaskExecutor, remoteTaskScheduledExecutor); InternalNode chosenNode = Iterables.get(nodeManager.getActiveConnectorNodes(CONNECTOR_ID), 0); - TaskId taskId1 = new TaskId("test", 1, 1); + TaskId taskId1 = new TaskId(new StageId("test", 1), 1, 0); RemoteTask remoteTask1 = remoteTaskFactory.createTableScanTask(taskId1, chosenNode, ImmutableList.of( @@ -449,7 +449,7 @@ public void testSplitCount() new Split(CONNECTOR_ID, new TestSplitRemote(), Lifespan.taskWide())), nodeTaskMap.createPartitionedSplitCountTracker(chosenNode, taskId1)); - TaskId taskId2 = new TaskId("test", 1, 2); + TaskId taskId2 = new TaskId(new StageId("test", 1), 2, 
0); RemoteTask remoteTask2 = remoteTaskFactory.createTableScanTask( taskId2, chosenNode, @@ -777,7 +777,7 @@ public void testEmptyAssignmentWithFullNodes() MockRemoteTaskFactory remoteTaskFactory = new MockRemoteTaskFactory(remoteTaskExecutor, remoteTaskScheduledExecutor); int task = 0; for (InternalNode node : assignments1.keySet()) { - TaskId taskId = new TaskId("test", 1, task); + TaskId taskId = new TaskId(new StageId("test", 1), task, 0); task++; MockRemoteTaskFactory.MockRemoteTask remoteTask = remoteTaskFactory.createTableScanTask(taskId, node, ImmutableList.copyOf(assignments1.get(node)), nodeTaskMap.createPartitionedSplitCountTracker(node, taskId)); remoteTask.startSplits(20); @@ -818,7 +818,7 @@ public void testMaxUnacknowledgedSplitsPerTask() int counter = 1; for (InternalNode node : nodeManager.getActiveConnectorNodes(CONNECTOR_ID)) { // Max out number of unacknowledged splits on each task - TaskId taskId = new TaskId("test", 1, counter); + TaskId taskId = new TaskId(new StageId("test", 1), counter, 0); counter++; MockRemoteTaskFactory.MockRemoteTask remoteTask = remoteTaskFactory.createTableScanTask(taskId, node, initialSplits.build(), nodeTaskMap.createPartitionedSplitCountTracker(node, taskId)); nodeTaskMap.addTask(node, remoteTask); diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlStageExecution.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java similarity index 73% rename from core/trino-main/src/test/java/io/trino/execution/TestSqlStageExecution.java rename to core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java index c6f76bf2949a..e27c012548ad 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlStageExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java @@ -15,17 +15,16 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import 
com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.SettableFuture; import io.trino.client.NodeVersion; import io.trino.cost.StatsAndCosts; -import io.trino.execution.MockRemoteTaskFactory.MockRemoteTask; import io.trino.execution.scheduler.SplitSchedulerStats; -import io.trino.failuredetector.NoOpFailureDetector; import io.trino.metadata.InternalNode; -import io.trino.server.DynamicFilterService; +import io.trino.operator.RetryPolicy; import io.trino.spi.QueryId; import io.trino.spi.type.Type; -import io.trino.spi.type.TypeOperators; import io.trino.sql.planner.Partitioning; import io.trino.sql.planner.PartitioningScheme; import io.trino.sql.planner.PlanFragment; @@ -48,10 +47,9 @@ import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.SessionTestUtils.TEST_SESSION; -import static io.trino.execution.SqlStageExecution.createSqlStageExecution; +import static io.trino.execution.SqlStage.createSqlStage; import static io.trino.execution.buffer.OutputBuffers.BufferType.ARBITRARY; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; -import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.operator.StageExecutionDescriptor.ungroupedExecution; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -64,7 +62,7 @@ import static org.testng.Assert.assertSame; import static org.testng.Assert.assertTrue; -public class TestSqlStageExecution +public class TestSqlStage { private ExecutorService executor; private ScheduledExecutorService scheduledExecutor; @@ -102,7 +100,7 @@ private void testFinalStageInfoInternal() NodeTaskMap nodeTaskMap = new NodeTaskMap(new FinalizerService()); StageId stageId = new StageId(new QueryId("query"), 0); - SqlStageExecution stage = createSqlStageExecution( + SqlStage stage = createSqlStage( stageId, createExchangePlanFragment(), 
ImmutableMap.of(), @@ -111,10 +109,7 @@ private void testFinalStageInfoInternal() true, nodeTaskMap, executor, - new NoOpFailureDetector(), - new DynamicFilterService(createTestMetadataManager(), new TypeOperators(), new DynamicFilterConfig()), new SplitSchedulerStats()); - stage.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY)); // add listener that fetches stage info when the final status is available SettableFuture finalStageInfo = SettableFuture.create(); @@ -133,7 +128,15 @@ private void testFinalStageInfoInternal() URI.create("http://10.0.0." + (i / 10_000) + ":" + (i % 10_000)), NodeVersion.UNKNOWN, false); - stage.scheduleTask(node, i); + stage.createTask( + node, + i, + 0, + Optional.empty(), + createInitialEmptyOutputBuffers(ARBITRARY), + ImmutableMultimap.of(), + ImmutableMultimap.of(), + ImmutableSet.of()); latch.countDown(); } } @@ -147,7 +150,7 @@ private void testFinalStageInfoInternal() // wait for some tasks to be created, and then abort the query latch.await(1, MINUTES); assertFalse(stage.getStageInfo().getTasks().isEmpty()); - stage.abort(); + stage.finish(); // once the final stage info is available, verify that it is complete StageInfo stageInfo = finalStageInfo.get(1, MINUTES); @@ -159,43 +162,6 @@ private void testFinalStageInfoInternal() addTasksTask.cancel(true); } - @Test - public void testIsAnyTaskBlocked() - { - NodeTaskMap nodeTaskMap = new NodeTaskMap(new FinalizerService()); - - StageId stageId = new StageId(new QueryId("query"), 0); - SqlStageExecution stage = createSqlStageExecution( - stageId, - createExchangePlanFragment(), - ImmutableMap.of(), - new MockRemoteTaskFactory(executor, scheduledExecutor), - TEST_SESSION, - true, - nodeTaskMap, - executor, - new NoOpFailureDetector(), - new DynamicFilterService(createTestMetadataManager(), new TypeOperators(), new DynamicFilterConfig()), - new SplitSchedulerStats()); - stage.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY)); - - InternalNode node1 = new 
InternalNode("other1", URI.create("http://127.0.0.1:11"), NodeVersion.UNKNOWN, false); - InternalNode node2 = new InternalNode("other2", URI.create("http://127.0.0.2:12"), NodeVersion.UNKNOWN, false); - MockRemoteTask task1 = (MockRemoteTask) stage.scheduleTask(node1, 1).get(); - MockRemoteTask task2 = (MockRemoteTask) stage.scheduleTask(node2, 2).get(); - - // both tasks' buffers are under utilized - assertFalse(stage.isAnyTaskBlocked()); - - // set one of the task's buffer to be over utilized - task1.setOutputBufferOverUtilized(true); - assertTrue(stage.isAnyTaskBlocked()); - - // set both the tasks' buffers to be over utilized - task2.setOutputBufferOverUtilized(true); - assertTrue(stage.isAnyTaskBlocked()); - } - private static PlanFragment createExchangePlanFragment() { PlanNode planNode = new RemoteSourceNode( @@ -203,7 +169,8 @@ private static PlanFragment createExchangePlanFragment() ImmutableList.of(new PlanFragmentId("source")), ImmutableList.of(new Symbol("column")), Optional.empty(), - REPARTITION); + REPARTITION, + RetryPolicy.NONE); ImmutableMap.Builder types = ImmutableMap.builder(); for (Symbol symbol : planNode.getOutputSymbols()) { diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java index a9de6bf73ebd..ea962a0437cf 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java @@ -339,7 +339,7 @@ public void testDynamicFilters() private SqlTask createInitialTask() { - TaskId taskId = new TaskId("query", 0, nextTaskId.incrementAndGet()); + TaskId taskId = new TaskId(new StageId("query", 0), nextTaskId.incrementAndGet(), 0); URI location = URI.create("fake://task/" + taskId); QueryContext queryContext = new QueryContext(new QueryId("query"), diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java 
b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java index f3ffcadee8f5..ea460581b070 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java @@ -115,7 +115,7 @@ public class TestSqlTaskExecution private static final OutputBufferId OUTPUT_BUFFER_ID = new OutputBufferId(0); private static final CatalogName CONNECTOR_ID = new CatalogName("test"); private static final Duration ASSERT_WAIT_TIMEOUT = new Duration(1, HOURS); - public static final TaskId TASK_ID = new TaskId("query", 0, 0); + public static final TaskId TASK_ID = new TaskId(new StageId("query", 0), 0, 0); @DataProvider public static Object[][] executionStrategies() diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java index 0271f4e067b7..05b684583ae5 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java @@ -34,6 +34,7 @@ import io.trino.metadata.InternalNode; import io.trino.operator.ExchangeClient; import io.trino.operator.ExchangeClientSupplier; +import io.trino.operator.RetryPolicy; import io.trino.spi.QueryId; import io.trino.spiller.LocalSpillManager; import io.trino.spiller.NodeSpillConfig; @@ -66,7 +67,7 @@ public class TestSqlTaskManager { - private static final TaskId TASK_ID = new TaskId("query", 0, 1); + private static final TaskId TASK_ID = new TaskId(new StageId("query", 0), 1, 0); public static final OutputBufferId OUT = new OutputBufferId(0); private final TaskExecutor taskExecutor; @@ -242,8 +243,8 @@ public void testSessionPropertyMemoryLimitOverride() .setMaxQueryTotalMemoryPerNode(DataSize.ofBytes(4)); try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig(), memoryConfig)) { - TaskId reduceLimitsId = new 
TaskId("q1", 0, 1); - TaskId increaseLimitsId = new TaskId("q2", 0, 1); + TaskId reduceLimitsId = new TaskId(new StageId("q1", 0), 1, 0); + TaskId increaseLimitsId = new TaskId(new StageId("q2", 0), 1, 0); QueryContext reducesLimitsContext = sqlTaskManager.getQueryContext(reduceLimitsId.getQueryId()); QueryContext attemptsIncreaseContext = sqlTaskManager.getQueryContext(increaseLimitsId.getQueryId()); @@ -338,7 +339,7 @@ public static class MockExchangeClientSupplier implements ExchangeClientSupplier { @Override - public ExchangeClient get(LocalMemoryContext systemMemoryContext) + public ExchangeClient get(LocalMemoryContext systemMemoryContext, TaskFailureListener taskFailureListener, RetryPolicy retryPolicy) { throw new UnsupportedOperationException(); } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java b/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java index e3694ea733f4..9693c6d7b064 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java @@ -35,7 +35,6 @@ import java.util.concurrent.ExecutorService; import static io.airlift.concurrent.Threads.daemonThreadsNamed; -import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.operator.StageExecutionDescriptor.ungroupedExecution; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -76,14 +75,14 @@ public void testBasicStateChanges() assertTrue(stateMachine.transitionToScheduling()); assertState(stateMachine, StageState.SCHEDULING); - assertTrue(stateMachine.transitionToScheduled()); - assertState(stateMachine, StageState.SCHEDULED); - assertTrue(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); - assertTrue(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); + 
assertTrue(stateMachine.transitionToPending()); + assertState(stateMachine, StageState.PENDING); + + assertTrue(stateMachine.transitionToRunning()); + assertState(stateMachine, StageState.RUNNING); assertTrue(stateMachine.transitionToFinished()); assertState(stateMachine, StageState.FINISHED); @@ -103,10 +102,6 @@ public void testPlanned() assertTrue(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); - stateMachine = createStageStateMachine(); - assertTrue(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); - stateMachine = createStageStateMachine(); assertTrue(stateMachine.transitionToFinished()); assertState(stateMachine, StageState.FINISHED); @@ -114,14 +109,6 @@ public void testPlanned() stateMachine = createStageStateMachine(); assertTrue(stateMachine.transitionToFailed(FAILED_CAUSE)); assertState(stateMachine, StageState.FAILED); - - stateMachine = createStageStateMachine(); - assertTrue(stateMachine.transitionToAborted()); - assertState(stateMachine, StageState.ABORTED); - - stateMachine = createStageStateMachine(); - assertTrue(stateMachine.transitionToCanceled()); - assertState(stateMachine, StageState.CANCELED); } @Test @@ -134,19 +121,11 @@ public void testScheduling() assertFalse(stateMachine.transitionToScheduling()); assertState(stateMachine, StageState.SCHEDULING); - assertTrue(stateMachine.transitionToScheduled()); - assertState(stateMachine, StageState.SCHEDULED); - stateMachine = createStageStateMachine(); stateMachine.transitionToScheduling(); assertTrue(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduling(); - assertTrue(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); - stateMachine = createStageStateMachine(); stateMachine.transitionToScheduling(); assertTrue(stateMachine.transitionToFinished()); @@ -156,56 +135,6 @@ public 
void testScheduling() stateMachine.transitionToScheduling(); assertTrue(stateMachine.transitionToFailed(FAILED_CAUSE)); assertState(stateMachine, StageState.FAILED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduling(); - assertTrue(stateMachine.transitionToAborted()); - assertState(stateMachine, StageState.ABORTED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduling(); - assertTrue(stateMachine.transitionToCanceled()); - assertState(stateMachine, StageState.CANCELED); - } - - @Test - public void testScheduled() - { - StageStateMachine stateMachine = createStageStateMachine(); - assertTrue(stateMachine.transitionToScheduled()); - assertState(stateMachine, StageState.SCHEDULED); - - assertFalse(stateMachine.transitionToScheduling()); - assertState(stateMachine, StageState.SCHEDULED); - - assertFalse(stateMachine.transitionToScheduled()); - assertState(stateMachine, StageState.SCHEDULED); - - assertTrue(stateMachine.transitionToRunning()); - assertState(stateMachine, StageState.RUNNING); - - assertTrue(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduled(); - assertTrue(stateMachine.transitionToFinished()); - assertState(stateMachine, StageState.FINISHED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduled(); - assertTrue(stateMachine.transitionToFailed(FAILED_CAUSE)); - assertState(stateMachine, StageState.FAILED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduled(); - assertTrue(stateMachine.transitionToAborted()); - assertState(stateMachine, StageState.ABORTED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduled(); - assertTrue(stateMachine.transitionToCanceled()); - assertState(stateMachine, StageState.CANCELED); } @Test @@ -218,74 +147,24 @@ public void testRunning() 
assertFalse(stateMachine.transitionToScheduling()); assertState(stateMachine, StageState.RUNNING); - assertFalse(stateMachine.transitionToScheduled()); - assertState(stateMachine, StageState.RUNNING); - assertFalse(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); - assertTrue(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToRunning(); - assertTrue(stateMachine.transitionToFinished()); - assertState(stateMachine, StageState.FINISHED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToRunning(); - assertTrue(stateMachine.transitionToFailed(FAILED_CAUSE)); - assertState(stateMachine, StageState.FAILED); + assertTrue(stateMachine.transitionToPending()); + assertState(stateMachine, StageState.PENDING); - stateMachine = createStageStateMachine(); - stateMachine.transitionToRunning(); - assertTrue(stateMachine.transitionToAborted()); - assertState(stateMachine, StageState.ABORTED); + assertTrue(stateMachine.transitionToRunning()); + assertState(stateMachine, StageState.RUNNING); stateMachine = createStageStateMachine(); stateMachine.transitionToRunning(); - assertTrue(stateMachine.transitionToCanceled()); - assertState(stateMachine, StageState.CANCELED); - } - - @Test - public void testFlushing() - { - StageStateMachine stateMachine = createStageStateMachine(); - assertTrue(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); - - assertFalse(stateMachine.transitionToScheduling()); - assertState(stateMachine, StageState.FLUSHING); - - assertFalse(stateMachine.transitionToScheduled()); - assertState(stateMachine, StageState.FLUSHING); - - assertFalse(stateMachine.transitionToRunning()); - assertState(stateMachine, StageState.FLUSHING); - - assertFalse(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); - - stateMachine = 
createStageStateMachine(); - stateMachine.transitionToFlushing(); assertTrue(stateMachine.transitionToFinished()); assertState(stateMachine, StageState.FINISHED); stateMachine = createStageStateMachine(); - stateMachine.transitionToFlushing(); + stateMachine.transitionToRunning(); assertTrue(stateMachine.transitionToFailed(FAILED_CAUSE)); assertState(stateMachine, StageState.FAILED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToFlushing(); - assertTrue(stateMachine.transitionToAborted()); - assertState(stateMachine, StageState.ABORTED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToFlushing(); - assertTrue(stateMachine.transitionToCanceled()); - assertState(stateMachine, StageState.CANCELED); } @Test @@ -306,24 +185,6 @@ public void testFailed() assertFinalState(stateMachine, StageState.FAILED); } - @Test - public void testAborted() - { - StageStateMachine stateMachine = createStageStateMachine(); - - assertTrue(stateMachine.transitionToAborted()); - assertFinalState(stateMachine, StageState.ABORTED); - } - - @Test - public void testCanceled() - { - StageStateMachine stateMachine = createStageStateMachine(); - - assertTrue(stateMachine.transitionToCanceled()); - assertFinalState(stateMachine, StageState.CANCELED); - } - private static void assertFinalState(StageStateMachine stateMachine, StageState expectedState) { assertTrue(expectedState.isDone()); @@ -333,24 +194,18 @@ private static void assertFinalState(StageStateMachine stateMachine, StageState assertFalse(stateMachine.transitionToScheduling()); assertState(stateMachine, expectedState); - assertFalse(stateMachine.transitionToScheduled()); + assertFalse(stateMachine.transitionToPending()); assertState(stateMachine, expectedState); assertFalse(stateMachine.transitionToRunning()); assertState(stateMachine, expectedState); - assertFalse(stateMachine.transitionToFlushing()); - assertState(stateMachine, expectedState); - 
assertFalse(stateMachine.transitionToFinished()); assertState(stateMachine, expectedState); assertFalse(stateMachine.transitionToFailed(FAILED_CAUSE)); assertState(stateMachine, expectedState); - assertFalse(stateMachine.transitionToAborted()); - assertState(stateMachine, expectedState); - // attempt to fail with another exception, which will fail assertFalse(stateMachine.transitionToFailed(new IOException("failure after finish"))); assertState(stateMachine, expectedState); @@ -359,7 +214,6 @@ private static void assertFinalState(StageStateMachine stateMachine, StageState private static void assertState(StageStateMachine stateMachine, StageState expectedState) { assertEquals(stateMachine.getStageId(), STAGE_ID); - assertSame(stateMachine.getSession(), TEST_SESSION); StageInfo stageInfo = stateMachine.getStageInfo(ImmutableList::of); assertEquals(stageInfo.getStageId(), STAGE_ID); @@ -383,7 +237,7 @@ private static void assertState(StageStateMachine stateMachine, StageState expec private StageStateMachine createStageStateMachine() { - return new StageStateMachine(STAGE_ID, TEST_SESSION, PLAN_FRAGMENT, ImmutableMap.of(), executor, new SplitSchedulerStats()); + return new StageStateMachine(STAGE_ID, PLAN_FRAGMENT, ImmutableMap.of(), executor, new SplitSchedulerStats()); } private static PlanFragment createValuesPlan() diff --git a/core/trino-main/src/test/java/io/trino/execution/TestStageStats.java b/core/trino-main/src/test/java/io/trino/execution/TestStageStats.java index 2776e04e123b..09d970a5baf7 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestStageStats.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestStageStats.java @@ -37,6 +37,7 @@ public class TestStageStats 4, 5, 6, + 1, 7, 8, @@ -108,6 +109,7 @@ private static void assertExpectedStageStats(StageStats actual) assertEquals(actual.getTotalTasks(), 4); assertEquals(actual.getRunningTasks(), 5); assertEquals(actual.getCompletedTasks(), 6); + assertEquals(actual.getFailedTasks(), 
1); assertEquals(actual.getTotalDrivers(), 7); assertEquals(actual.getQueuedDrivers(), 8); diff --git a/core/trino-main/src/test/java/io/trino/execution/executor/SimulationController.java b/core/trino-main/src/test/java/io/trino/execution/executor/SimulationController.java index a78414d63a61..d8771580b4da 100644 --- a/core/trino-main/src/test/java/io/trino/execution/executor/SimulationController.java +++ b/core/trino-main/src/test/java/io/trino/execution/executor/SimulationController.java @@ -16,6 +16,7 @@ import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimaps; +import io.trino.execution.StageId; import io.trino.execution.TaskId; import io.trino.execution.executor.SimulationTask.IntermediateTask; import io.trino.execution.executor.SimulationTask.LeafTask; @@ -182,13 +183,13 @@ private void createTask(TaskSpecification specification) runningTasks.put(specification, new LeafTask( taskExecutor, specification, - new TaskId(specification.getName(), 0, runningTasks.get(specification).size() + completedTasks.get(specification).size()))); + new TaskId(new StageId(specification.getName(), 0), runningTasks.get(specification).size() + completedTasks.get(specification).size(), 0))); } else { runningTasks.put(specification, new IntermediateTask( taskExecutor, specification, - new TaskId(specification.getName(), 0, runningTasks.get(specification).size() + completedTasks.get(specification).size()))); + new TaskId(new StageId(specification.getName(), 0), runningTasks.get(specification).size() + completedTasks.get(specification).size(), 0))); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/executor/TestTaskExecutor.java b/core/trino-main/src/test/java/io/trino/execution/executor/TestTaskExecutor.java index 4318b733aa62..48dfc81db0d7 100644 --- a/core/trino-main/src/test/java/io/trino/execution/executor/TestTaskExecutor.java +++ 
b/core/trino-main/src/test/java/io/trino/execution/executor/TestTaskExecutor.java @@ -19,6 +19,7 @@ import io.airlift.testing.TestingTicker; import io.airlift.units.Duration; import io.trino.execution.SplitRunner; +import io.trino.execution.StageId; import io.trino.execution.TaskId; import org.testng.annotations.Test; @@ -55,7 +56,7 @@ public void testTasksComplete() ticker.increment(20, MILLISECONDS); try { - TaskId taskId = new TaskId("test", 0, 0); + TaskId taskId = new TaskId(new StageId("test", 0), 0, 0); TaskHandle taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); Phaser beginPhase = new Phaser(); @@ -149,8 +150,8 @@ public void testQuantaFairness() ticker.increment(20, MILLISECONDS); try { - TaskHandle shortQuantaTaskHandle = taskExecutor.addTask(new TaskId("shortQuanta", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); - TaskHandle longQuantaTaskHandle = taskExecutor.addTask(new TaskId("longQuanta", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + TaskHandle shortQuantaTaskHandle = taskExecutor.addTask(new TaskId(new StageId("short_quanta", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + TaskHandle longQuantaTaskHandle = taskExecutor.addTask(new TaskId(new StageId("long_quanta", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); Phaser endQuantaPhaser = new Phaser(); @@ -183,7 +184,7 @@ public void testLevelMovement() ticker.increment(20, MILLISECONDS); try { - TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId("test", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId(new StageId("test", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); Phaser globalPhaser = new Phaser(); globalPhaser.bulkRegister(3); // 2 taskExecutor threads + test thread @@ -224,9 +225,9 @@ public void 
testLevelMultipliers() try { for (int i = 0; i < (LEVEL_THRESHOLD_SECONDS.length - 1); i++) { TaskHandle[] taskHandles = { - taskExecutor.addTask(new TaskId("test1", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()), - taskExecutor.addTask(new TaskId("test2", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()), - taskExecutor.addTask(new TaskId("test3", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()) + taskExecutor.addTask(new TaskId(new StageId("test1", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()), + taskExecutor.addTask(new TaskId(new StageId("test2", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()), + taskExecutor.addTask(new TaskId(new StageId("test3", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()) }; // move task 0 to next level @@ -299,7 +300,7 @@ public void testTaskHandle() taskExecutor.start(); try { - TaskId taskId = new TaskId("test", 0, 0); + TaskId taskId = new TaskId(new StageId("test", 0), 0, 0); TaskHandle taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); Phaser beginPhase = new Phaser(); @@ -331,8 +332,8 @@ public void testTaskHandle() public void testLevelContributionCap() { MultilevelSplitQueue splitQueue = new MultilevelSplitQueue(2); - TaskHandle handle0 = new TaskHandle(new TaskId("test0", 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS), OptionalInt.empty()); - TaskHandle handle1 = new TaskHandle(new TaskId("test1", 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS), OptionalInt.empty()); + TaskHandle handle0 = new TaskHandle(new TaskId(new StageId("test0", 0), 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS), OptionalInt.empty()); + TaskHandle handle1 = new TaskHandle(new TaskId(new StageId("test1", 0), 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS), OptionalInt.empty()); for (int i = 0; i < 
(LEVEL_THRESHOLD_SECONDS.length - 1); i++) { long levelAdvanceTime = SECONDS.toNanos(LEVEL_THRESHOLD_SECONDS[i + 1] - LEVEL_THRESHOLD_SECONDS[i]); @@ -351,7 +352,7 @@ public void testLevelContributionCap() public void testUpdateLevelWithCap() { MultilevelSplitQueue splitQueue = new MultilevelSplitQueue(2); - TaskHandle handle0 = new TaskHandle(new TaskId("test0", 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS), OptionalInt.empty()); + TaskHandle handle0 = new TaskHandle(new TaskId(new StageId("test0", 0), 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS), OptionalInt.empty()); long quantaNanos = MINUTES.toNanos(10); handle0.addScheduledNanos(quantaNanos); @@ -373,7 +374,7 @@ public void testMinMaxDriversPerTask() TaskExecutor taskExecutor = new TaskExecutor(4, 16, 1, maxDriversPerTask, splitQueue, ticker); taskExecutor.start(); try { - TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId("test", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId(new StageId("test", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); // enqueue all batches of splits int batchCount = 4; @@ -414,7 +415,7 @@ public void testUserSpecifiedMaxDriversPerTask() taskExecutor.start(); try { // overwrite the max drivers per task to be 1 - TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId("test", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.of(1)); + TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId(new StageId("test", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.of(1)); // enqueue all batches of splits int batchCount = 4; diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestBroadcastOutputBufferManager.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestBroadcastOutputBufferManager.java index e951a17a3f7e..97b0e3fc27d4 100644 --- 
a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestBroadcastOutputBufferManager.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestBroadcastOutputBufferManager.java @@ -13,13 +13,10 @@ */ package io.trino.execution.scheduler; -import com.google.common.collect.ImmutableList; import io.trino.execution.buffer.OutputBuffers; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import org.testng.annotations.Test; -import java.util.concurrent.atomic.AtomicReference; - import static io.trino.execution.buffer.OutputBuffers.BROADCAST_PARTITION_ID; import static io.trino.execution.buffer.OutputBuffers.BufferType.BROADCAST; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; @@ -30,34 +27,35 @@ public class TestBroadcastOutputBufferManager @Test public void test() { - AtomicReference outputBufferTarget = new AtomicReference<>(); - BroadcastOutputBufferManager hashOutputBufferManager = new BroadcastOutputBufferManager(outputBufferTarget::set); - assertEquals(outputBufferTarget.get(), createInitialEmptyOutputBuffers(BROADCAST)); + BroadcastOutputBufferManager hashOutputBufferManager = new BroadcastOutputBufferManager(); + assertEquals(hashOutputBufferManager.getOutputBuffers(), createInitialEmptyOutputBuffers(BROADCAST)); - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(0)), false); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(0)); OutputBuffers expectedOutputBuffers = createInitialEmptyOutputBuffers(BROADCAST).withBuffer(new OutputBufferId(0), BROADCAST_PARTITION_ID); - assertEquals(outputBufferTarget.get(), expectedOutputBuffers); + assertEquals(hashOutputBufferManager.getOutputBuffers(), expectedOutputBuffers); - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(1), new OutputBufferId(2)), false); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(1)); + hashOutputBufferManager.addOutputBuffer(new 
OutputBufferId(2)); expectedOutputBuffers = expectedOutputBuffers.withBuffer(new OutputBufferId(1), BROADCAST_PARTITION_ID); expectedOutputBuffers = expectedOutputBuffers.withBuffer(new OutputBufferId(2), BROADCAST_PARTITION_ID); - assertEquals(outputBufferTarget.get(), expectedOutputBuffers); + assertEquals(hashOutputBufferManager.getOutputBuffers(), expectedOutputBuffers); // set no more buffers - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(3)), true); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(3)); + hashOutputBufferManager.noMoreBuffers(); expectedOutputBuffers = expectedOutputBuffers.withBuffer(new OutputBufferId(3), BROADCAST_PARTITION_ID); expectedOutputBuffers = expectedOutputBuffers.withNoMoreBufferIds(); - assertEquals(outputBufferTarget.get(), expectedOutputBuffers); + assertEquals(hashOutputBufferManager.getOutputBuffers(), expectedOutputBuffers); // try to add another buffer, which should not result in an error // and output buffers should not change - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(5)), false); - assertEquals(outputBufferTarget.get(), expectedOutputBuffers); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(5)); + assertEquals(hashOutputBufferManager.getOutputBuffers(), expectedOutputBuffers); // try to set no more buffers again, which should not result in an error // and output buffers should not change - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(6)), true); - assertEquals(outputBufferTarget.get(), expectedOutputBuffers); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(6)); + assertEquals(hashOutputBufferManager.getOutputBuffers(), expectedOutputBuffers); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountScheduler.java index bba17755fdae..393204445b43 100644 --- 
a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountScheduler.java @@ -18,6 +18,7 @@ import io.trino.execution.MockRemoteTaskFactory; import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker; import io.trino.execution.RemoteTask; +import io.trino.execution.StageId; import io.trino.execution.TaskId; import io.trino.metadata.InternalNode; import org.testng.annotations.AfterClass; @@ -63,7 +64,7 @@ public void testSingleNode() { FixedCountScheduler nodeScheduler = new FixedCountScheduler( (node, partition) -> Optional.of(taskFactory.createTableScanTask( - new TaskId("test", 1, 1), + new TaskId(new StageId("test", 1), 1, 0), node, ImmutableList.of(), new PartitionedSplitCountTracker(delta -> {}))), generateRandomNodes(1)); @@ -80,7 +81,7 @@ public void testMultipleNodes() { FixedCountScheduler nodeScheduler = new FixedCountScheduler( (node, partition) -> Optional.of(taskFactory.createTableScanTask( - new TaskId("test", 1, 1), + new TaskId(new StageId("test", 1), 1, 0), node, ImmutableList.of(), new PartitionedSplitCountTracker(delta -> {}))), generateRandomNodes(5)); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestPartitionedOutputBufferManager.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestPartitionedOutputBufferManager.java index 0fe667dd32a5..e209078c9dda 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestPartitionedOutputBufferManager.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestPartitionedOutputBufferManager.java @@ -13,13 +13,11 @@ */ package io.trino.execution.scheduler; -import com.google.common.collect.ImmutableList; import io.trino.execution.buffer.OutputBuffers; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import org.testng.annotations.Test; import java.util.Map; -import 
java.util.concurrent.atomic.AtomicReference; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -32,30 +30,28 @@ public class TestPartitionedOutputBufferManager @Test public void test() { - AtomicReference outputBufferTarget = new AtomicReference<>(); - - PartitionedOutputBufferManager hashOutputBufferManager = new PartitionedOutputBufferManager(FIXED_HASH_DISTRIBUTION, 4, outputBufferTarget::set); + PartitionedOutputBufferManager hashOutputBufferManager = new PartitionedOutputBufferManager(FIXED_HASH_DISTRIBUTION, 4); // output buffers are set immediately when the manager is created - assertOutputBuffers(outputBufferTarget.get()); + assertOutputBuffers(hashOutputBufferManager.getOutputBuffers()); // add buffers, which does not cause an error - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(0)), false); - assertOutputBuffers(outputBufferTarget.get()); - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(3)), true); - assertOutputBuffers(outputBufferTarget.get()); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(0)); + assertOutputBuffers(hashOutputBufferManager.getOutputBuffers()); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(3)); + assertOutputBuffers(hashOutputBufferManager.getOutputBuffers()); // try to a buffer out side of the partition range, which should result in an error - assertThatThrownBy(() -> hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(5)), false)) + assertThatThrownBy(() -> hashOutputBufferManager.addOutputBuffer(new OutputBufferId(5))) .isInstanceOf(IllegalStateException.class) .hasMessage("Unexpected new output buffer 5"); - assertOutputBuffers(outputBufferTarget.get()); + assertOutputBuffers(hashOutputBufferManager.getOutputBuffers()); // try to a buffer out side of the partition range, which should result in an error - 
assertThatThrownBy(() -> hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(6)), true)) + assertThatThrownBy(() -> hashOutputBufferManager.addOutputBuffer(new OutputBufferId(6))) .isInstanceOf(IllegalStateException.class) .hasMessage("Unexpected new output buffer 6"); - assertOutputBuffers(outputBufferTarget.get()); + assertOutputBuffers(hashOutputBufferManager.getOutputBuffers()); } private static void assertOutputBuffers(OutputBuffers outputBuffers) diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestPhasedExecutionSchedule.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestPhasedExecutionSchedule.java index b72254e1a51d..efc82be346c0 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestPhasedExecutionSchedule.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestPhasedExecutionSchedule.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.trino.cost.StatsAndCosts; +import io.trino.operator.RetryPolicy; import io.trino.spi.type.Type; import io.trino.sql.planner.Partitioning; import io.trino.sql.planner.PartitioningScheme; @@ -156,7 +157,8 @@ private static PlanFragment createExchangePlanFragment(String name, PlanFragment .collect(toImmutableList()), fragments[0].getPartitioningScheme().getOutputLayout(), Optional.empty(), - REPARTITION); + REPARTITION, + RetryPolicy.NONE); return createFragment(planNode); } @@ -166,7 +168,7 @@ private static PlanFragment createUnionPlanFragment(String name, PlanFragment... 
PlanNode planNode = new UnionNode( new PlanNodeId(name + "_id"), Stream.of(fragments) - .map(fragment -> new RemoteSourceNode(new PlanNodeId(fragment.getId().toString()), fragment.getId(), fragment.getPartitioningScheme().getOutputLayout(), Optional.empty(), REPARTITION)) + .map(fragment -> new RemoteSourceNode(new PlanNodeId(fragment.getId().toString()), fragment.getId(), fragment.getPartitioningScheme().getOutputLayout(), Optional.empty(), REPARTITION, RetryPolicy.NONE)) .collect(toImmutableList()), ImmutableListMultimap.of(), ImmutableList.of()); @@ -185,7 +187,7 @@ private static PlanFragment createBroadcastJoinPlanFragment(String name, PlanFra false, Optional.empty()); - RemoteSourceNode remote = new RemoteSourceNode(new PlanNodeId("build_id"), buildFragment.getId(), ImmutableList.of(), Optional.empty(), REPLICATE); + RemoteSourceNode remote = new RemoteSourceNode(new PlanNodeId("build_id"), buildFragment.getId(), ImmutableList.of(), Optional.empty(), REPLICATE, RetryPolicy.NONE); PlanNode join = new JoinNode( new PlanNodeId(name + "_id"), INNER, @@ -208,8 +210,8 @@ private static PlanFragment createBroadcastJoinPlanFragment(String name, PlanFra private static PlanFragment createJoinPlanFragment(JoinNode.Type joinType, String name, PlanFragment buildFragment, PlanFragment probeFragment) { - RemoteSourceNode probe = new RemoteSourceNode(new PlanNodeId("probe_id"), probeFragment.getId(), ImmutableList.of(), Optional.empty(), REPARTITION); - RemoteSourceNode build = new RemoteSourceNode(new PlanNodeId("build_id"), buildFragment.getId(), ImmutableList.of(), Optional.empty(), REPARTITION); + RemoteSourceNode probe = new RemoteSourceNode(new PlanNodeId("probe_id"), probeFragment.getId(), ImmutableList.of(), Optional.empty(), REPARTITION, RetryPolicy.NONE); + RemoteSourceNode build = new RemoteSourceNode(new PlanNodeId("build_id"), buildFragment.getId(), ImmutableList.of(), Optional.empty(), REPARTITION, RetryPolicy.NONE); PlanNode planNode = new JoinNode( new 
PlanNodeId(name + "_id"), joinType, diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java index da911dbd7a60..ab0b53d8afc2 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; import io.airlift.units.Duration; import io.trino.Session; import io.trino.client.NodeVersion; @@ -28,17 +27,17 @@ import io.trino.execution.NodeTaskMap; import io.trino.execution.PartitionedSplitsInfo; import io.trino.execution.RemoteTask; -import io.trino.execution.SqlStageExecution; +import io.trino.execution.SqlStage; import io.trino.execution.StageId; import io.trino.execution.TableExecuteContextManager; import io.trino.execution.TableInfo; -import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.failuredetector.NoOpFailureDetector; import io.trino.metadata.InMemoryNodeManager; import io.trino.metadata.InternalNode; import io.trino.metadata.InternalNodeManager; import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; +import io.trino.operator.RetryPolicy; import io.trino.server.DynamicFilterService; import io.trino.spi.QueryId; import io.trino.spi.connector.ConnectorPartitionHandle; @@ -49,12 +48,10 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.TypeOperators; import io.trino.split.ConnectorAwareSplitSource; -import io.trino.split.SplitSource; import io.trino.sql.DynamicFilters; import io.trino.sql.planner.Partitioning; import io.trino.sql.planner.PartitioningScheme; import io.trino.sql.planner.PlanFragment; -import 
io.trino.sql.planner.StageExecutionPlan; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.plan.DynamicFilterId; @@ -85,8 +82,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.SessionTestUtils.TEST_SESSION; -import static io.trino.execution.buffer.OutputBuffers.BufferType.PARTITIONED; -import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; +import static io.trino.execution.scheduler.PipelinedStageExecution.createPipelinedStageExecution; import static io.trino.execution.scheduler.ScheduleResult.BlockedReason.SPLIT_QUEUES_FULL; import static io.trino.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsStageScheduler; import static io.trino.metadata.MetadataManager.createTestMetadataManager; @@ -96,6 +92,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.DynamicFilters.createDynamicFilterExpression; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; @@ -113,7 +110,7 @@ public class TestSourcePartitionedScheduler { - public static final OutputBufferId OUT = new OutputBufferId(0); + private static final PlanNodeId TABLE_SCAN_NODE_ID = new PlanNodeId("plan_id"); private static final CatalogName CONNECTOR_ID = TEST_TABLE_HANDLE.getCatalogName(); private static final QueryId QUERY_ID = new QueryId("query"); private static final DynamicFilterId DYNAMIC_FILTER_ID = new DynamicFilterId("filter1"); @@ -151,11 +148,11 @@ public void destroyExecutor() @Test public void testScheduleNoSplits() { - StageExecutionPlan 
plan = createPlan(createFixedSplitSource(0, TestingSplit::createRemoteSplit)); + PlanFragment plan = createFragment(); NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + PipelinedStageExecution stage = createStageExecution(plan, nodeTaskMap); - StageScheduler scheduler = getSourcePartitionedScheduler(plan, stage, nodeManager, nodeTaskMap, 1); + StageScheduler scheduler = getSourcePartitionedScheduler(createFixedSplitSource(0, TestingSplit::createRemoteSplit), stage, nodeManager, nodeTaskMap, 1); ScheduleResult scheduleResult = scheduler.schedule(); @@ -168,11 +165,11 @@ public void testScheduleNoSplits() @Test public void testScheduleSplitsOneAtATime() { - StageExecutionPlan plan = createPlan(createFixedSplitSource(60, TestingSplit::createRemoteSplit)); + PlanFragment plan = createFragment(); NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + PipelinedStageExecution stage = createStageExecution(plan, nodeTaskMap); - StageScheduler scheduler = getSourcePartitionedScheduler(plan, stage, nodeManager, nodeTaskMap, 1); + StageScheduler scheduler = getSourcePartitionedScheduler(createFixedSplitSource(60, TestingSplit::createRemoteSplit), stage, nodeManager, nodeTaskMap, 1); for (int i = 0; i < 60; i++) { ScheduleResult scheduleResult = scheduler.schedule(); @@ -206,11 +203,11 @@ public void testScheduleSplitsOneAtATime() @Test public void testScheduleSplitsBatched() { - StageExecutionPlan plan = createPlan(createFixedSplitSource(60, TestingSplit::createRemoteSplit)); + PlanFragment plan = createFragment(); NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + PipelinedStageExecution stage = createStageExecution(plan, nodeTaskMap); - StageScheduler scheduler = getSourcePartitionedScheduler(plan, stage, nodeManager, nodeTaskMap, 
7); + StageScheduler scheduler = getSourcePartitionedScheduler(createFixedSplitSource(60, TestingSplit::createRemoteSplit), stage, nodeManager, nodeTaskMap, 7); for (int i = 0; i <= (60 / 7); i++) { ScheduleResult scheduleResult = scheduler.schedule(); @@ -244,11 +241,11 @@ public void testScheduleSplitsBatched() @Test public void testScheduleSplitsBlock() { - StageExecutionPlan plan = createPlan(createFixedSplitSource(80, TestingSplit::createRemoteSplit)); + PlanFragment plan = createFragment(); NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + PipelinedStageExecution stage = createStageExecution(plan, nodeTaskMap); - StageScheduler scheduler = getSourcePartitionedScheduler(plan, stage, nodeManager, nodeTaskMap, 1); + StageScheduler scheduler = getSourcePartitionedScheduler(createFixedSplitSource(80, TestingSplit::createRemoteSplit), stage, nodeManager, nodeTaskMap, 1); // schedule first 60 splits, which will cause the scheduler to block for (int i = 0; i <= 60; i++) { @@ -311,11 +308,11 @@ public void testScheduleSplitsBlock() public void testScheduleSlowSplitSource() { QueuedSplitSource queuedSplitSource = new QueuedSplitSource(TestingSplit::createRemoteSplit); - StageExecutionPlan plan = createPlan(queuedSplitSource); + PlanFragment plan = createFragment(); NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + PipelinedStageExecution stage = createStageExecution(plan, nodeTaskMap); - StageScheduler scheduler = getSourcePartitionedScheduler(plan, stage, nodeManager, nodeTaskMap, 1); + StageScheduler scheduler = getSourcePartitionedScheduler(queuedSplitSource, stage, nodeManager, nodeTaskMap, 1); // schedule with no splits - will block ScheduleResult scheduleResult = scheduler.schedule(); @@ -336,13 +333,13 @@ public void testNoNodes() InMemoryNodeManager nodeManager = new InMemoryNodeManager(); 
NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, new NodeSchedulerConfig().setIncludeCoordinator(false), nodeTaskMap)); - StageExecutionPlan plan = createPlan(createFixedSplitSource(20, TestingSplit::createRemoteSplit)); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + PlanFragment plan = createFragment(); + PipelinedStageExecution stage = createStageExecution(plan, nodeTaskMap); StageScheduler scheduler = newSourcePartitionedSchedulerAsStageScheduler( stage, - Iterables.getOnlyElement(plan.getSplitSources().keySet()), - Iterables.getOnlyElement(plan.getSplitSources().values()), + TABLE_SCAN_NODE_ID, + new ConnectorAwareSplitSource(CONNECTOR_ID, createFixedSplitSource(20, TestingSplit::createRemoteSplit)), new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 2, new DynamicFilterService(metadata, typeOperators, new DynamicFilterConfig()), @@ -364,9 +361,9 @@ public void testBalancedSplitAssignment() NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); // Schedule 15 splits - there are 3 nodes, each node should get 5 splits - StageExecutionPlan firstPlan = createPlan(createFixedSplitSource(15, TestingSplit::createRemoteSplit)); - SqlStageExecution firstStage = createSqlStageExecution(firstPlan, nodeTaskMap); - StageScheduler firstScheduler = getSourcePartitionedScheduler(firstPlan, firstStage, nodeManager, nodeTaskMap, 200); + PlanFragment firstPlan = createFragment(); + PipelinedStageExecution firstStage = createStageExecution(firstPlan, nodeTaskMap); + StageScheduler firstScheduler = getSourcePartitionedScheduler(createFixedSplitSource(15, TestingSplit::createRemoteSplit), firstStage, nodeManager, nodeTaskMap, 200); ScheduleResult scheduleResult = firstScheduler.schedule(); assertEffectivelyFinished(scheduleResult, firstScheduler); @@ -383,9 +380,9 @@ public void testBalancedSplitAssignment() 
nodeManager.addNode(CONNECTOR_ID, additionalNode); // Schedule 5 splits in another query. Since the new node does not have any splits, all 5 splits are assigned to the new node - StageExecutionPlan secondPlan = createPlan(createFixedSplitSource(5, TestingSplit::createRemoteSplit)); - SqlStageExecution secondStage = createSqlStageExecution(secondPlan, nodeTaskMap); - StageScheduler secondScheduler = getSourcePartitionedScheduler(secondPlan, secondStage, nodeManager, nodeTaskMap, 200); + PlanFragment secondPlan = createFragment(); + PipelinedStageExecution secondStage = createStageExecution(secondPlan, nodeTaskMap); + StageScheduler secondScheduler = getSourcePartitionedScheduler(createFixedSplitSource(5, TestingSplit::createRemoteSplit), secondStage, nodeManager, nodeTaskMap, 200); scheduleResult = secondScheduler.schedule(); assertEffectivelyFinished(scheduleResult, secondScheduler); @@ -411,14 +408,14 @@ public void testNewTaskScheduledWhenChildStageBufferIsUnderutilized() new InternalNode("other3", URI.create("http://127.0.0.1:13"), NodeVersion.UNKNOWN, false)); NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, new NodeSchedulerConfig().setIncludeCoordinator(false), nodeTaskMap, new Duration(0, SECONDS))); - StageExecutionPlan plan = createPlan(createFixedSplitSource(500, TestingSplit::createRemoteSplit)); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + PlanFragment plan = createFragment(); + PipelinedStageExecution stage = createStageExecution(plan, nodeTaskMap); // setting under utilized child output buffer StageScheduler scheduler = newSourcePartitionedSchedulerAsStageScheduler( stage, - Iterables.getOnlyElement(plan.getSplitSources().keySet()), - Iterables.getOnlyElement(plan.getSplitSources().values()), + TABLE_SCAN_NODE_ID, + new ConnectorAwareSplitSource(CONNECTOR_ID, createFixedSplitSource(500, TestingSplit::createRemoteSplit)), new 
DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 500, new DynamicFilterService(metadata, typeOperators, new DynamicFilterConfig()), @@ -455,14 +452,14 @@ public void testNoNewTaskScheduledWhenChildStageBufferIsOverutilized() new InternalNode("other3", URI.create("http://127.0.0.1:13"), NodeVersion.UNKNOWN, false)); NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, new NodeSchedulerConfig().setIncludeCoordinator(false), nodeTaskMap, new Duration(0, SECONDS))); - StageExecutionPlan plan = createPlan(createFixedSplitSource(400, TestingSplit::createRemoteSplit)); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + PlanFragment plan = createFragment(); + PipelinedStageExecution stage = createStageExecution(plan, nodeTaskMap); // setting over utilized child output buffer StageScheduler scheduler = newSourcePartitionedSchedulerAsStageScheduler( stage, - Iterables.getOnlyElement(plan.getSplitSources().keySet()), - Iterables.getOnlyElement(plan.getSplitSources().values()), + TABLE_SCAN_NODE_ID, + new ConnectorAwareSplitSource(CONNECTOR_ID, createFixedSplitSource(400, TestingSplit::createRemoteSplit)), new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 400, new DynamicFilterService(metadata, typeOperators, new DynamicFilterConfig()), @@ -490,9 +487,9 @@ public void testNoNewTaskScheduledWhenChildStageBufferIsOverutilized() @Test public void testDynamicFiltersUnblockedOnBlockedBuildSource() { - StageExecutionPlan plan = createPlan(createBlockedSplitSource()); + PlanFragment plan = createFragment(); NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + PipelinedStageExecution stage = createStageExecution(plan, nodeTaskMap); NodeScheduler nodeScheduler = new NodeScheduler(new 
UniformNodeSelectorFactory(nodeManager, new NodeSchedulerConfig().setIncludeCoordinator(false), nodeTaskMap)); DynamicFilterService dynamicFilterService = new DynamicFilterService(metadata, typeOperators, new DynamicFilterConfig()); dynamicFilterService.registerQuery( @@ -503,8 +500,8 @@ public void testDynamicFiltersUnblockedOnBlockedBuildSource() ImmutableSet.of(DYNAMIC_FILTER_ID)); StageScheduler scheduler = newSourcePartitionedSchedulerAsStageScheduler( stage, - Iterables.getOnlyElement(plan.getSplitSources().keySet()), - Iterables.getOnlyElement(plan.getSplitSources().values()), + TABLE_SCAN_NODE_ID, + new ConnectorAwareSplitSource(CONNECTOR_ID, createBlockedSplitSource()), new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 2, dynamicFilterService, @@ -531,7 +528,7 @@ public void testDynamicFiltersUnblockedOnBlockedBuildSource() assertEquals(scheduleResult.getSplitsScheduled(), 0); } - private static void assertPartitionedSplitCount(SqlStageExecution stage, int expectedPartitionedSplitCount) + private static void assertPartitionedSplitCount(PipelinedStageExecution stage, int expectedPartitionedSplitCount) { assertEquals(stage.getAllTasks().stream().mapToInt(remoteTask -> remoteTask.getPartitionedSplitsInfo().getCount()).sum(), expectedPartitionedSplitCount); } @@ -552,8 +549,8 @@ private static void assertEffectivelyFinished(ScheduleResult scheduleResult, Sta } private StageScheduler getSourcePartitionedScheduler( - StageExecutionPlan plan, - SqlStageExecution stage, + ConnectorSplitSource splitSource, + PipelinedStageExecution stage, InternalNodeManager nodeManager, NodeTaskMap nodeTaskMap, int splitBatchSize) @@ -564,13 +561,11 @@ private StageScheduler getSourcePartitionedScheduler( .setMaxPendingSplitsPerTask(0); NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, nodeSchedulerConfig, nodeTaskMap)); - PlanNodeId sourceNode = 
Iterables.getOnlyElement(plan.getSplitSources().keySet()); - SplitSource splitSource = Iterables.getOnlyElement(plan.getSplitSources().values()); - SplitPlacementPolicy placementPolicy = new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(splitSource.getCatalogName())), stage::getAllTasks); + SplitPlacementPolicy placementPolicy = new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks); return newSourcePartitionedSchedulerAsStageScheduler( stage, - sourceNode, - splitSource, + TABLE_SCAN_NODE_ID, + new ConnectorAwareSplitSource(CONNECTOR_ID, splitSource), placementPolicy, splitBatchSize, new DynamicFilterService(metadata, typeOperators, new DynamicFilterConfig()), @@ -578,15 +573,14 @@ private StageScheduler getSourcePartitionedScheduler( () -> false); } - private static StageExecutionPlan createPlan(ConnectorSplitSource splitSource) + private static PlanFragment createFragment() { Symbol symbol = new Symbol("column"); Symbol buildSymbol = new Symbol("buildColumn"); // table scan with splitCount splits - PlanNodeId tableScanNodeId = new PlanNodeId("plan_id"); TableScanNode tableScan = TableScanNode.newInstance( - tableScanNodeId, + TABLE_SCAN_NODE_ID, TEST_TABLE_HANDLE, ImmutableList.of(symbol), ImmutableMap.of(symbol, new TestingColumnHandle("column")), @@ -597,8 +591,8 @@ private static StageExecutionPlan createPlan(ConnectorSplitSource splitSource) tableScan, createDynamicFilterExpression(TEST_SESSION, createTestMetadataManager(), DYNAMIC_FILTER_ID, VARCHAR, symbol.toSymbolReference())); - RemoteSourceNode remote = new RemoteSourceNode(new PlanNodeId("remote_id"), new PlanFragmentId("plan_fragment_id"), ImmutableList.of(buildSymbol), Optional.empty(), REPLICATE); - PlanFragment testFragment = new PlanFragment( + RemoteSourceNode remote = new RemoteSourceNode(new PlanNodeId("remote_id"), new PlanFragmentId("plan_fragment_id"), ImmutableList.of(buildSymbol), 
Optional.empty(), REPLICATE, RetryPolicy.NONE); + return new PlanFragment( new PlanFragmentId("plan_id"), new JoinNode(new PlanNodeId("join_id"), INNER, @@ -617,17 +611,11 @@ private static StageExecutionPlan createPlan(ConnectorSplitSource splitSource) Optional.empty()), ImmutableMap.of(symbol, VARCHAR), SOURCE_DISTRIBUTION, - ImmutableList.of(tableScanNodeId), + ImmutableList.of(TABLE_SCAN_NODE_ID), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), ungroupedExecution(), StatsAndCosts.empty(), Optional.empty()); - - return new StageExecutionPlan( - testFragment, - ImmutableMap.of(tableScanNodeId, new ConnectorAwareSplitSource(CONNECTOR_ID, splitSource)), - ImmutableList.of(), - ImmutableMap.of(tableScanNodeId, new TableInfo(new QualifiedObjectName("test", "test", "test"), TupleDomain.all()))); } private static ConnectorSplitSource createBlockedSplitSource() @@ -663,26 +651,31 @@ private static ConnectorSplitSource createFixedSplitSource(int splitCount, Suppl return new FixedSplitSource(splits.build()); } - private SqlStageExecution createSqlStageExecution(StageExecutionPlan tableScanPlan, NodeTaskMap nodeTaskMap) + private PipelinedStageExecution createStageExecution(PlanFragment fragment, NodeTaskMap nodeTaskMap) { StageId stageId = new StageId(QUERY_ID, 0); - SqlStageExecution stage = SqlStageExecution.createSqlStageExecution(stageId, - tableScanPlan.getFragment(), - tableScanPlan.getTables(), + SqlStage stage = SqlStage.createSqlStage(stageId, + fragment, + ImmutableMap.of(TABLE_SCAN_NODE_ID, new TableInfo(new QualifiedObjectName("test", "test", "test"), TupleDomain.all())), new MockRemoteTaskFactory(queryExecutor, scheduledExecutor), TEST_SESSION, true, nodeTaskMap, queryExecutor, - new NoOpFailureDetector(), - new DynamicFilterService(metadata, typeOperators, new DynamicFilterConfig()), new SplitSchedulerStats()); - - stage.setOutputBuffers(createInitialEmptyOutputBuffers(PARTITIONED) - 
.withBuffer(OUT, 0) - .withNoMoreBufferIds()); - - return stage; + ImmutableMap.Builder outputBuffers = ImmutableMap.builder(); + outputBuffers.put(fragment.getId(), new PartitionedOutputBufferManager(FIXED_HASH_DISTRIBUTION, 1)); + fragment.getRemoteSourceNodes().stream() + .flatMap(node -> node.getSourceFragmentIds().stream()) + .forEach(fragmentId -> outputBuffers.put(fragmentId, new PartitionedOutputBufferManager(FIXED_HASH_DISTRIBUTION, 10))); + return createPipelinedStageExecution( + stage, + outputBuffers.build(), + TaskLifecycleListener.NO_OP, + new NoOpFailureDetector(), + queryExecutor, + Optional.of(new int[] {0}), + 0); } private static class QueuedSplitSource diff --git a/core/trino-main/src/test/java/io/trino/memory/TestMemoryTracking.java b/core/trino-main/src/test/java/io/trino/memory/TestMemoryTracking.java index 2f4cd3e8500b..5bd65e2f9c19 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestMemoryTracking.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestMemoryTracking.java @@ -16,6 +16,7 @@ import io.airlift.stats.TestingGcMonitor; import io.airlift.units.DataSize; import io.trino.ExceededMemoryLimitException; +import io.trino.execution.StageId; import io.trino.execution.TaskId; import io.trino.execution.TaskStateMachine; import io.trino.memory.context.LocalMemoryContext; @@ -107,7 +108,7 @@ public void setUpTest() queryMaxSpillSize, spillSpaceTracker); taskContext = queryContext.addTaskContext( - new TaskStateMachine(new TaskId("query", 0, 0), notificationExecutor), + new TaskStateMachine(new TaskId(new StageId("query", 0), 0, 0), notificationExecutor), testSessionBuilder().build(), () -> {}, true, diff --git a/core/trino-main/src/test/java/io/trino/operator/MockExchangeRequestProcessor.java b/core/trino-main/src/test/java/io/trino/operator/MockExchangeRequestProcessor.java index ca5c1214f0a7..16a22087f65e 100644 --- a/core/trino-main/src/test/java/io/trino/operator/MockExchangeRequestProcessor.java +++ 
b/core/trino-main/src/test/java/io/trino/operator/MockExchangeRequestProcessor.java @@ -38,6 +38,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import static com.google.common.base.Preconditions.checkState; import static com.google.common.net.HttpHeaders.CONTENT_TYPE; @@ -48,6 +49,7 @@ import static io.trino.server.InternalHeaders.TRINO_BUFFER_COMPLETE; import static io.trino.server.InternalHeaders.TRINO_PAGE_NEXT_TOKEN; import static io.trino.server.InternalHeaders.TRINO_PAGE_TOKEN; +import static io.trino.server.InternalHeaders.TRINO_TASK_FAILED; import static io.trino.server.InternalHeaders.TRINO_TASK_INSTANCE_ID; import static io.trino.server.PagesResponseWriter.SERIALIZED_PAGES_MAGIC; import static org.testng.Assert.assertEquals; @@ -73,11 +75,21 @@ public void addPage(URI location, Page page) buffers.getUnchecked(location).addPage(page); } + public void addPage(URI location, SerializedPage page) + { + buffers.getUnchecked(location).addPage(page); + } + public void setComplete(URI location) { buffers.getUnchecked(location).setCompleted(); } + public void setFailed(URI location, RuntimeException failure) + { + buffers.getUnchecked(location).setFailed(failure); + } + @Override public Response handle(Request request) { @@ -112,12 +124,14 @@ public Response handle(Request request) return new TestingResponse( status, - ImmutableListMultimap.of( - CONTENT_TYPE, TRINO_PAGES, - TRINO_TASK_INSTANCE_ID, String.valueOf(result.getTaskInstanceId()), - TRINO_PAGE_TOKEN, String.valueOf(result.getToken()), - TRINO_PAGE_NEXT_TOKEN, String.valueOf(result.getNextToken()), - TRINO_BUFFER_COMPLETE, String.valueOf(result.isBufferComplete())), + ImmutableListMultimap.builder() + .put(CONTENT_TYPE, TRINO_PAGES) + .put(TRINO_TASK_INSTANCE_ID, String.valueOf(result.getTaskInstanceId())) + .put(TRINO_PAGE_TOKEN, 
String.valueOf(result.getToken())) + .put(TRINO_PAGE_NEXT_TOKEN, String.valueOf(result.getNextToken())) + .put(TRINO_BUFFER_COMPLETE, String.valueOf(result.isBufferComplete())) + .put(TRINO_TASK_FAILED, "false") + .build(), bytes); } @@ -151,6 +165,7 @@ private static class MockBuffer private final AtomicBoolean completed = new AtomicBoolean(); private final AtomicLong token = new AtomicLong(); private final BlockingQueue serializedPages = new LinkedBlockingQueue<>(); + private final AtomicReference failure = new AtomicReference<>(); private MockBuffer(URI location) { @@ -162,6 +177,12 @@ public void setCompleted() completed.set(true); } + public synchronized void addPage(SerializedPage page) + { + checkState(completed.get() != Boolean.TRUE, "Location %s is complete", location); + serializedPages.add(page); + } + public synchronized void addPage(Page page) { checkState(completed.get() != Boolean.TRUE, "Location %s is complete", location); @@ -170,6 +191,11 @@ public synchronized void addPage(Page page) } } + public void setFailed(RuntimeException t) + { + failure.set(t); + } + public BufferResult getPages(long sequenceId, DataSize maxSize) { // if location is complete return GONE @@ -177,6 +203,11 @@ public BufferResult getPages(long sequenceId, DataSize maxSize) return BufferResult.emptyResults(TASK_INSTANCE_ID, token.get(), true); } + RuntimeException failure = this.failure.get(); + if (failure != null) { + throw failure; + } + assertEquals(sequenceId, token.get(), "token"); // wait for a single page to arrive diff --git a/core/trino-main/src/test/java/io/trino/operator/TestDeduplicationExchangeClientBuffer.java b/core/trino-main/src/test/java/io/trino/operator/TestDeduplicationExchangeClientBuffer.java new file mode 100644 index 000000000000..4cbb70bc3f08 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/TestDeduplicationExchangeClientBuffer.java @@ -0,0 +1,476 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you 
may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Multimap; +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.slice.Slice; +import io.airlift.units.DataSize; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; +import io.trino.execution.buffer.PageCodecMarker; +import io.trino.execution.buffer.SerializedPage; +import io.trino.spi.TrinoException; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.units.DataSize.Unit.BYTE; +import static io.airlift.units.DataSize.Unit.KILOBYTE; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestDeduplicationExchangeClientBuffer +{ + private static final DataSize ONE_KB = DataSize.of(1, KILOBYTE); + + @Test + public void testIsBlocked() + { + // immediate close + try 
(ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + ListenableFuture blocked = buffer.isBlocked(); + assertBlocked(blocked); + buffer.close(); + assertNotBlocked(blocked); + } + + // empty set of tasks + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + ListenableFuture blocked = buffer.isBlocked(); + assertBlocked(blocked); + buffer.noMoreTasks(); + assertNotBlocked(blocked); + } + + // single task finishes before noMoreTasks + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + ListenableFuture blocked = buffer.isBlocked(); + assertBlocked(blocked); + + TaskId taskId = createTaskId(0, 0); + buffer.addTask(taskId); + assertBlocked(blocked); + + buffer.taskFinished(taskId); + assertBlocked(blocked); + + buffer.noMoreTasks(); + assertNotBlocked(blocked); + } + + // single task finishes after noMoreTasks + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + ListenableFuture blocked = buffer.isBlocked(); + assertBlocked(blocked); + + TaskId taskId = createTaskId(0, 0); + buffer.addTask(taskId); + assertBlocked(blocked); + + buffer.noMoreTasks(); + assertBlocked(blocked); + + buffer.taskFinished(taskId); + assertNotBlocked(blocked); + } + + // single task fails before noMoreTasks + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + ListenableFuture blocked = buffer.isBlocked(); + assertBlocked(blocked); + + TaskId taskId = createTaskId(0, 0); + buffer.addTask(taskId); + assertBlocked(blocked); + + buffer.taskFailed(taskId, new RuntimeException()); + assertBlocked(blocked); + + buffer.noMoreTasks(); + assertNotBlocked(blocked); + } + + // single task fails after noMoreTasks + try (ExchangeClientBuffer buffer = new 
DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + ListenableFuture blocked = buffer.isBlocked(); + assertBlocked(blocked); + + TaskId taskId = createTaskId(0, 0); + buffer.addTask(taskId); + assertBlocked(blocked); + + buffer.noMoreTasks(); + assertBlocked(blocked); + + buffer.taskFailed(taskId, new RuntimeException()); + assertNotBlocked(blocked); + } + + // cancelled blocked future doesn't affect other blocked futures + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + ListenableFuture blocked1 = buffer.isBlocked(); + ListenableFuture blocked2 = buffer.isBlocked(); + assertBlocked(blocked1); + assertBlocked(blocked2); + + blocked2.cancel(true); + + assertBlocked(blocked1); + assertNotBlocked(blocked2); + } + } + + @Test + public void testPollPage() + { + testPollPages(ImmutableListMultimap.of(), ImmutableMap.of(), ImmutableList.of()); + testPollPages( + ImmutableListMultimap.builder() + .put(createTaskId(0, 0), createPage("p0a0v0")) + .build(), + ImmutableMap.of(), + ImmutableList.of("p0a0v0")); + testPollPages( + ImmutableListMultimap.builder() + .put(createTaskId(0, 0), createPage("p0a0v0")) + .put(createTaskId(0, 1), createPage("p0a1v0")) + .build(), + ImmutableMap.of(), + ImmutableList.of("p0a1v0")); + testPollPages( + ImmutableListMultimap.builder() + .put(createTaskId(0, 0), createPage("p0a0v0")) + .put(createTaskId(1, 0), createPage("p1a0v0")) + .put(createTaskId(0, 1), createPage("p0a1v0")) + .build(), + ImmutableMap.of(), + ImmutableList.of("p0a1v0")); + testPollPages( + ImmutableListMultimap.builder() + .put(createTaskId(0, 0), createPage("p0a0v0")) + .put(createTaskId(1, 0), createPage("p1a0v0")) + .put(createTaskId(0, 1), createPage("p0a1v0")) + .build(), + ImmutableMap.of( + createTaskId(2, 0), + new RuntimeException("error")), + ImmutableList.of("p0a1v0")); + RuntimeException error = new RuntimeException("error"); + testPollPagesFailure( + 
ImmutableListMultimap.builder() + .put(createTaskId(0, 0), createPage("p0a0v0")) + .put(createTaskId(1, 0), createPage("p1a0v0")) + .put(createTaskId(0, 1), createPage("p0a1v0")) + .build(), + ImmutableMap.of( + createTaskId(2, 2), + error), + error); + testPollPagesFailure( + ImmutableListMultimap.builder() + .put(createTaskId(0, 0), createPage("p0a0v0")) + .put(createTaskId(1, 0), createPage("p1a0v0")) + .put(createTaskId(0, 1), createPage("p0a1v0")) + .build(), + ImmutableMap.of( + createTaskId(0, 1), + error), + error); + } + + private static void testPollPages(Multimap pages, Map failures, List expectedValues) + { + List actualPages = pollPages(pages, failures); + List actualValues = actualPages.stream() + .map(SerializedPage::getSlice) + .map(Slice::toStringUtf8) + .collect(toImmutableList()); + assertThat(actualValues).containsExactlyInAnyOrderElementsOf(expectedValues); + } + + private static void testPollPagesFailure(Multimap pages, Map failures, Throwable expectedFailure) + { + assertThatThrownBy(() -> pollPages(pages, failures)).isEqualTo(expectedFailure); + } + + private static List pollPages(Multimap pages, Map failures) + { + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + for (TaskId taskId : Sets.union(pages.keySet(), failures.keySet())) { + buffer.addTask(taskId); + } + for (Map.Entry page : pages.entries()) { + buffer.addPages(page.getKey(), ImmutableList.of(page.getValue())); + } + for (Map.Entry failure : failures.entrySet()) { + buffer.taskFailed(failure.getKey(), failure.getValue()); + } + for (TaskId taskId : Sets.difference(pages.keySet(), failures.keySet())) { + buffer.taskFinished(taskId); + } + buffer.noMoreTasks(); + + ImmutableList.Builder result = ImmutableList.builder(); + while (true) { + SerializedPage page = buffer.pollPage(); + if (page == null) { + break; + } + result.add(page); + } + assertTrue(buffer.isFinished()); + return result.build(); + } + } + + 
@Test + public void testRemovePagesForPreviousAttempts() + { + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + assertEquals(buffer.getRetainedSizeInBytes(), 0); + + TaskId partition0Attempt0 = createTaskId(0, 0); + TaskId partition1Attempt0 = createTaskId(1, 0); + TaskId partition0Attempt1 = createTaskId(0, 1); + + SerializedPage page1 = createPage("textofrandomlength"); + SerializedPage page2 = createPage("textwithdifferentlength"); + SerializedPage page3 = createPage("smalltext"); + + buffer.addTask(partition0Attempt0); + buffer.addPages(partition0Attempt0, ImmutableList.of(page1)); + buffer.addTask(partition1Attempt0); + buffer.addPages(partition1Attempt0, ImmutableList.of(page2)); + + assertThat(buffer.getRetainedSizeInBytes()).isGreaterThan(0); + assertEquals(buffer.getRetainedSizeInBytes(), page1.getRetainedSizeInBytes() + page2.getRetainedSizeInBytes()); + + buffer.addTask(partition0Attempt1); + assertEquals(buffer.getRetainedSizeInBytes(), 0); + + buffer.addPages(partition0Attempt1, ImmutableList.of(page3)); + assertEquals(buffer.getRetainedSizeInBytes(), page3.getRetainedSizeInBytes()); + } + } + + @Test + public void testBufferOverflow() + { + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), DataSize.of(100, BYTE), RetryPolicy.QUERY)) { + TaskId task = createTaskId(0, 0); + + SerializedPage page1 = createPage("1234"); + SerializedPage page2 = createPage("123456789"); + assertThat(page1.getRetainedSizeInBytes()).isLessThanOrEqualTo(100); + assertThat(page1.getRetainedSizeInBytes() + page2.getRetainedSizeInBytes()).isGreaterThan(100); + + buffer.addTask(task); + buffer.addPages(task, ImmutableList.of(page1)); + + assertFalse(buffer.isFinished()); + assertBlocked(buffer.isBlocked()); + assertEquals(buffer.getRetainedSizeInBytes(), page1.getRetainedSizeInBytes()); + + buffer.addPages(task, ImmutableList.of(page2)); + 
assertTrue(buffer.isFinished()); + assertNotBlocked(buffer.isBlocked()); + assertEquals(buffer.getRetainedSizeInBytes(), 0); + assertEquals(buffer.getBufferedPageCount(), 0); + + assertThatThrownBy(buffer::pollPage) + .isInstanceOf(TrinoException.class) + .hasMessage("Retries for queries with large result set currently unsupported"); + } + } + + @Test + public void testIsFinished() + { + // close right away + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + assertFalse(buffer.isFinished()); + buffer.close(); + assertTrue(buffer.isFinished()); + } + + // 0 tasks + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + assertFalse(buffer.isFinished()); + buffer.noMoreTasks(); + assertTrue(buffer.isFinished()); + } + + // single task producing no results, finish before noMoreTasks + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + assertFalse(buffer.isFinished()); + + TaskId taskId = createTaskId(0, 0); + buffer.addTask(taskId); + assertFalse(buffer.isFinished()); + + buffer.taskFinished(taskId); + assertFalse(buffer.isFinished()); + + buffer.noMoreTasks(); + assertTrue(buffer.isFinished()); + } + + // single task producing no results, finish after noMoreTasks + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + assertFalse(buffer.isFinished()); + + TaskId taskId = createTaskId(0, 0); + buffer.addTask(taskId); + assertFalse(buffer.isFinished()); + + buffer.noMoreTasks(); + assertFalse(buffer.isFinished()); + + buffer.taskFinished(taskId); + assertTrue(buffer.isFinished()); + } + + // single task producing no results, fail before noMoreTasks + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + 
assertFalse(buffer.isFinished()); + + TaskId taskId = createTaskId(0, 0); + buffer.addTask(taskId); + assertFalse(buffer.isFinished()); + + buffer.taskFailed(taskId, new RuntimeException()); + assertFalse(buffer.isFinished()); + + buffer.noMoreTasks(); + assertTrue(buffer.isFinished()); + } + + // single task producing no results, fail after noMoreTasks + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + assertFalse(buffer.isFinished()); + + TaskId taskId = createTaskId(0, 0); + buffer.addTask(taskId); + assertFalse(buffer.isFinished()); + + buffer.noMoreTasks(); + assertFalse(buffer.isFinished()); + + buffer.taskFailed(taskId, new RuntimeException()); + assertTrue(buffer.isFinished()); + } + + // single task producing one page, fail after noMoreTasks + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + assertFalse(buffer.isFinished()); + + TaskId taskId = createTaskId(0, 0); + buffer.addTask(taskId); + buffer.addPages(taskId, ImmutableList.of(createPage("page"))); + assertFalse(buffer.isFinished()); + + buffer.noMoreTasks(); + assertFalse(buffer.isFinished()); + + buffer.taskFailed(taskId, new RuntimeException()); + assertTrue(buffer.isFinished()); + } + + // single task producing one page, finish after noMoreTasks + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + assertFalse(buffer.isFinished()); + + TaskId taskId = createTaskId(0, 0); + buffer.addTask(taskId); + buffer.addPages(taskId, ImmutableList.of(createPage("page"))); + assertFalse(buffer.isFinished()); + + buffer.noMoreTasks(); + assertFalse(buffer.isFinished()); + + buffer.taskFinished(taskId); + assertFalse(buffer.isFinished()); + + assertNotNull(buffer.pollPage()); + assertTrue(buffer.isFinished()); + } + + // single task producing one page, finish before noMoreTasks + try 
(ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + assertFalse(buffer.isFinished()); + + TaskId taskId = createTaskId(0, 0); + buffer.addTask(taskId); + buffer.addPages(taskId, ImmutableList.of(createPage("page"))); + assertFalse(buffer.isFinished()); + + buffer.taskFinished(taskId); + assertFalse(buffer.isFinished()); + + buffer.noMoreTasks(); + assertFalse(buffer.isFinished()); + + assertNotNull(buffer.pollPage()); + assertTrue(buffer.isFinished()); + } + } + + @Test + public void testRemainingBufferCapacity() + { + try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + assertFalse(buffer.isFinished()); + + TaskId taskId = createTaskId(0, 0); + buffer.addTask(taskId); + SerializedPage page = createPage("page"); + buffer.addPages(taskId, ImmutableList.of(page)); + + assertEquals(buffer.getRemainingCapacityInBytes(), ONE_KB.toBytes() - page.getRetainedSizeInBytes()); + } + } + + private static TaskId createTaskId(int partition, int attempt) + { + return new TaskId(new StageId("query", 0), partition, attempt); + } + + private static SerializedPage createPage(String value) + { + return new SerializedPage(utf8Slice(value), PageCodecMarker.MarkerSet.empty(), 1, value.length()); + } + + private static void assertNotBlocked(ListenableFuture blocked) + { + assertTrue(blocked.isDone()); + } + + private static void assertBlocked(ListenableFuture blocked) + { + assertFalse(blocked.isDone()); + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/TestExchangeClient.java b/core/trino-main/src/test/java/io/trino/operator/TestExchangeClient.java index 36be978b6929..430aa6c89e1b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestExchangeClient.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestExchangeClient.java @@ -13,7 +13,9 @@ */ package io.trino.operator; +import com.google.common.collect.ImmutableList; 
import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -27,37 +29,52 @@ import io.airlift.units.Duration; import io.trino.FeaturesConfig.DataIntegrityVerification; import io.trino.block.BlockAssertions; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; +import io.trino.execution.buffer.PageCodecMarker; import io.trino.execution.buffer.PagesSerde; import io.trino.execution.buffer.SerializedPage; import io.trino.memory.context.SimpleLocalMemoryContext; import io.trino.spi.Page; import io.trino.spi.TrinoException; +import io.trino.spi.TrinoTransportException; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.net.URI; +import java.util.ArrayList; +import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; import static com.google.common.collect.Maps.uniqueIndex; +import static com.google.common.collect.Sets.newConcurrentHashSet; import static com.google.common.io.ByteStreams.toByteArray; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static 
io.airlift.slice.Slices.utf8Slice; import static io.airlift.testing.Assertions.assertLessThan; import static io.trino.execution.buffer.TestingPagesSerdeFactory.testingPagesSerde; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.trino.testing.assertions.Assert.assertEventually; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -95,6 +112,72 @@ public void tearDown() @Test public void testHappyPath() + throws Exception + { + DataSize maxResponseSize = DataSize.of(10, Unit.MEGABYTE); + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + + List pages = ImmutableList.of(createSerializedPage("val1"), createSerializedPage("value2"), createSerializedPage("valllue3")); + + URI location = URI.create("http://localhost:8080"); + pages.forEach(page -> processor.addPage(location, page)); + processor.setComplete(location); + + TestingExchangeClientBuffer buffer = new TestingExchangeClientBuffer(DataSize.of(1, Unit.MEGABYTE)); + + @SuppressWarnings("resource") + ExchangeClient exchangeClient = new ExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + buffer, + maxResponseSize, + 1, + new Duration(1, TimeUnit.MINUTES), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + + assertThat(buffer.getAllTasks()).isEmpty(); + 
assertThat(buffer.getPages().asMap()).isEmpty(); + assertThat(buffer.getFinishedTasks()).isEmpty(); + assertThat(buffer.getFailedTasks().asMap()).isEmpty(); + assertFalse(buffer.isNoMoreTasks()); + + TaskId taskId = new TaskId(new StageId("query", 1), 0, 0); + exchangeClient.addLocation(taskId, location); + assertThat(buffer.getAllTasks()).containsExactly(taskId); + exchangeClient.noMoreLocations(); + assertTrue(buffer.isNoMoreTasks()); + + buffer.whenTaskFinished(taskId).get(10, SECONDS); + assertThat(buffer.getFinishedTasks()).containsExactly(taskId); + assertThat(buffer.getPages().get(taskId)).hasSize(3); + assertThat(buffer.getFailedTasks().asMap()).isEmpty(); + + assertFalse(exchangeClient.isFinished()); + buffer.setFinished(true); + assertTrue(exchangeClient.isFinished()); + + ExchangeClientStatus status = exchangeClient.getStatus(); + assertEquals(status.getBufferedPages(), 0); + + // client should have sent only 3 requests: one to get all pages, one to acknowledge and one to get the done signal + assertStatus(status.getPageBufferClientStatuses().get(0), location, "closed", 3, 3, 3, "not scheduled"); + + exchangeClient.close(); + + assertEventually(() -> assertEquals(exchangeClient.getStatus().getPageBufferClientStatuses().get(0).getHttpRequestState(), "not scheduled", "httpRequestState")); + + assertThat(buffer.getFinishedTasks()).containsExactly(taskId); + assertThat(buffer.getFailedTasks().asMap()).isEmpty(); + assertThat(buffer.getPages().size()).isEqualTo(3); + } + + @Test + public void testStreamingHappyPath() { DataSize maxResponseSize = DataSize.of(10, Unit.MEGABYTE); MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); @@ -109,7 +192,7 @@ public void testHappyPath() ExchangeClient exchangeClient = new ExchangeClient( "localhost", DataIntegrityVerification.ABORT, - DataSize.of(32, Unit.MEGABYTE), + new StreamingExchangeClientBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), maxResponseSize, 1, new Duration(1, 
TimeUnit.MINUTES), @@ -117,40 +200,123 @@ public void testHappyPath() new TestingHttpClient(processor, scheduler), scheduler, new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), - pageBufferClientCallbackExecutor); + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); - exchangeClient.addLocation(location); + exchangeClient.addLocation(new TaskId(new StageId("query", 1), 0, 0), location); exchangeClient.noMoreLocations(); - assertEquals(exchangeClient.isClosed(), false); + assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), createPage(1)); - assertEquals(exchangeClient.isClosed(), false); + assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), createPage(2)); - assertEquals(exchangeClient.isClosed(), false); + assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), createPage(3)); assertNull(getNextPage(exchangeClient)); - assertEquals(exchangeClient.isClosed(), true); + assertTrue(exchangeClient.isFinished()); ExchangeClientStatus status = exchangeClient.getStatus(); assertEquals(status.getBufferedPages(), 0); - assertEquals(status.getBufferedBytes(), 0); - // client should have sent only 2 requests: one to get all pages and once to get the done signal + // client should have sent only 3 requests: one to get all pages, one to acknowledge and one to get the done signal assertStatus(status.getPageBufferClientStatuses().get(0), location, "closed", 3, 3, 3, "not scheduled"); + + exchangeClient.close(); } - @Test(timeOut = 10000) + @Test public void testAddLocation() throws Exception { DataSize maxResponseSize = DataSize.of(10, Unit.MEGABYTE); MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + TaskId task1 = new TaskId(new StageId("query", 1), 0, 0); + TaskId task2 = new TaskId(new StageId("query", 1), 1, 0); + TaskId task3 = new TaskId(new StageId("query", 1), 2, 0); + + URI location1 = 
URI.create("http://localhost:8080/1"); + URI location2 = URI.create("http://localhost:8080/2"); + URI location3 = URI.create("http://localhost:8080/3"); + + processor.addPage(location1, createSerializedPage("location-1-page-1")); + processor.addPage(location1, createSerializedPage("location-1-page-2")); + + TestingExchangeClientBuffer buffer = new TestingExchangeClientBuffer(DataSize.of(1, Unit.MEGABYTE)); + + @SuppressWarnings("resource") + ExchangeClient exchangeClient = new ExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + buffer, + maxResponseSize, + 1, + new Duration(1, TimeUnit.MINUTES), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + + assertThat(buffer.getAllTasks()).isEmpty(); + assertThat(buffer.getPages().asMap()).isEmpty(); + assertThat(buffer.getFinishedTasks()).isEmpty(); + assertThat(buffer.getFailedTasks().asMap()).isEmpty(); + assertFalse(buffer.isNoMoreTasks()); + + exchangeClient.addLocation(task1, location1); + assertThat(buffer.getAllTasks()).containsExactly(task1); + assertTaskIsNotFinished(buffer, task1); + + processor.setComplete(location1); + buffer.whenTaskFinished(task1).get(10, SECONDS); + assertThat(buffer.getPages().get(task1)).hasSize(2); + assertThat(buffer.getFinishedTasks()).containsExactly(task1); + + exchangeClient.addLocation(task2, location2); + assertThat(buffer.getAllTasks()).containsExactlyInAnyOrder(task1, task2); + assertTaskIsNotFinished(buffer, task2); + + processor.setComplete(location2); + buffer.whenTaskFinished(task2).get(10, SECONDS); + assertThat(buffer.getFinishedTasks()).containsExactlyInAnyOrder(task1, task2); + assertThat(buffer.getPages().get(task2)).hasSize(0); + + exchangeClient.addLocation(task3, location3); + assertThat(buffer.getAllTasks()).containsExactlyInAnyOrder(task1, task2, task3); + assertTaskIsNotFinished(buffer, task3); 
+ + exchangeClient.noMoreLocations(); + assertTrue(buffer.isNoMoreTasks()); + + assertThat(buffer.getAllTasks()).containsExactlyInAnyOrder(task1, task2, task3); + assertTaskIsNotFinished(buffer, task3); + + exchangeClient.close(); + + assertEventually(() -> assertEquals(exchangeClient.getStatus().getPageBufferClientStatuses().get(0).getHttpRequestState(), "not scheduled", "httpRequestState")); + assertEventually(() -> assertEquals(exchangeClient.getStatus().getPageBufferClientStatuses().get(1).getHttpRequestState(), "not scheduled", "httpRequestState")); + assertEventually(() -> assertEquals(exchangeClient.getStatus().getPageBufferClientStatuses().get(2).getHttpRequestState(), "not scheduled", "httpRequestState")); + + assertThat(buffer.getFinishedTasks()).containsExactlyInAnyOrder(task1, task2, task3); + assertThat(buffer.getFailedTasks().asMap()).isEmpty(); + + assertTrue(exchangeClient.isFinished()); + } + + @Test(timeOut = 10000) + public void testStreamingAddLocation() + throws Exception + { + DataSize maxResponseSize = DataSize.of(10, Unit.MEGABYTE); + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + @SuppressWarnings("resource") ExchangeClient exchangeClient = new ExchangeClient( "localhost", DataIntegrityVerification.ABORT, - DataSize.of(32, Unit.MEGABYTE), + new StreamingExchangeClientBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), maxResponseSize, 1, new Duration(1, TimeUnit.MINUTES), @@ -158,48 +324,76 @@ public void testAddLocation() new TestingHttpClient(processor, newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-testAddLocation-%s"))), scheduler, new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), - pageBufferClientCallbackExecutor); + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); URI location1 = URI.create("http://localhost:8081/foo"); processor.addPage(location1, createPage(1)); processor.addPage(location1, createPage(2)); 
processor.addPage(location1, createPage(3)); processor.setComplete(location1); - exchangeClient.addLocation(location1); + exchangeClient.addLocation(new TaskId(new StageId("query", 1), 0, 0), location1); - assertEquals(exchangeClient.isClosed(), false); + assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), createPage(1)); - assertEquals(exchangeClient.isClosed(), false); + assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), createPage(2)); - assertEquals(exchangeClient.isClosed(), false); + assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), createPage(3)); - assertFalse(tryGetFutureValue(exchangeClient.isBlocked(), 10, MILLISECONDS).isPresent()); - assertEquals(exchangeClient.isClosed(), false); + assertNull(exchangeClient.pollPage()); + ListenableFuture firstBlocked = exchangeClient.isBlocked(); + assertFalse(tryGetFutureValue(firstBlocked, 10, MILLISECONDS).isPresent()); + assertFalse(firstBlocked.isDone()); + + assertNull(exchangeClient.pollPage()); + ListenableFuture secondBlocked = exchangeClient.isBlocked(); + assertFalse(tryGetFutureValue(secondBlocked, 10, MILLISECONDS).isPresent()); + assertFalse(secondBlocked.isDone()); + + assertNull(exchangeClient.pollPage()); + ListenableFuture thirdBlocked = exchangeClient.isBlocked(); + assertFalse(tryGetFutureValue(thirdBlocked, 10, MILLISECONDS).isPresent()); + assertFalse(thirdBlocked.isDone()); + + thirdBlocked.cancel(true); + assertTrue(thirdBlocked.isDone()); + assertFalse(tryGetFutureValue(firstBlocked, 10, MILLISECONDS).isPresent()); + assertFalse(firstBlocked.isDone()); + assertFalse(tryGetFutureValue(secondBlocked, 10, MILLISECONDS).isPresent()); + assertFalse(secondBlocked.isDone()); + + assertFalse(exchangeClient.isFinished()); URI location2 = URI.create("http://localhost:8082/bar"); processor.addPage(location2, createPage(4)); processor.addPage(location2, createPage(5)); 
processor.addPage(location2, createPage(6)); processor.setComplete(location2); - exchangeClient.addLocation(location2); + exchangeClient.addLocation(new TaskId(new StageId("query", 1), 1, 0), location2); - assertEquals(exchangeClient.isClosed(), false); + tryGetFutureValue(firstBlocked, 5, SECONDS); + assertTrue(firstBlocked.isDone()); + tryGetFutureValue(secondBlocked, 5, SECONDS); + assertTrue(secondBlocked.isDone()); + + assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), createPage(4)); - assertEquals(exchangeClient.isClosed(), false); + assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), createPage(5)); - assertEquals(exchangeClient.isClosed(), false); + assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), createPage(6)); assertFalse(tryGetFutureValue(exchangeClient.isBlocked(), 10, MILLISECONDS).isPresent()); - assertEquals(exchangeClient.isClosed(), false); + assertFalse(exchangeClient.isFinished()); exchangeClient.noMoreLocations(); // The transition to closed may happen asynchronously, since it requires that all the HTTP clients // receive a final GONE response, so just spin until it's closed or the test times out. 
- while (!exchangeClient.isClosed()) { + while (!exchangeClient.isFinished()) { Thread.sleep(1); } + exchangeClient.close(); ImmutableMap statuses = uniqueIndex(exchangeClient.getStatus().getPageBufferClientStatuses(), PageBufferClientStatus::getUri); assertStatus(statuses.get(location1), location1, "closed", 3, 3, 3, "not scheduled"); @@ -207,7 +401,202 @@ public void testAddLocation() } @Test - public void testBufferLimit() + public void testDeduplication() + throws Exception + { + DataSize maxResponseSize = DataSize.of(10, Unit.MEGABYTE); + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + + TaskId taskP0A0 = new TaskId(new StageId("query", 1), 0, 0); + TaskId taskP1A0 = new TaskId(new StageId("query", 1), 1, 0); + TaskId taskP0A1 = new TaskId(new StageId("query", 1), 0, 1); + + URI locationP0A0 = URI.create("http://localhost:8080/1"); + URI locationP1A0 = URI.create("http://localhost:8080/2"); + URI locationP0A1 = URI.create("http://localhost:8080/3"); + + processor.addPage(locationP1A0, createSerializedPage("location-1-page-1")); + processor.addPage(locationP0A1, createSerializedPage("location-2-page-1")); + processor.addPage(locationP0A1, createSerializedPage("location-2-page-2")); + + @SuppressWarnings("resource") + ExchangeClient exchangeClient = new ExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + new DeduplicationExchangeClientBuffer(scheduler, DataSize.of(1, Unit.KILOBYTE), RetryPolicy.QUERY), + maxResponseSize, + 1, + new Duration(1, SECONDS), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + + exchangeClient.addLocation(taskP0A0, locationP0A0); + exchangeClient.addLocation(taskP1A0, locationP1A0); + exchangeClient.addLocation(taskP0A1, locationP0A1); + + processor.setComplete(locationP0A0); + // Failing attempt 0. 
Results from all tasks for attempt 0 must be discarded. + processor.setFailed(locationP1A0, new RuntimeException("failure")); + processor.setComplete(locationP0A1); + + assertFalse(exchangeClient.isFinished()); + assertThatThrownBy(() -> exchangeClient.isBlocked().get(50, MILLISECONDS)) + .isInstanceOf(TimeoutException.class); + + exchangeClient.noMoreLocations(); + exchangeClient.isBlocked().get(10, SECONDS); + + List pageValues = new ArrayList<>(); + while (!exchangeClient.isFinished()) { + SerializedPage page = exchangeClient.pollPage(); + if (page == null) { + break; + } + pageValues.add(page.getSlice().toStringUtf8()); + } + + assertThat(pageValues).containsExactlyInAnyOrder("location-2-page-1", "location-2-page-2"); + assertEventually(() -> assertTrue(exchangeClient.isFinished())); + + assertEventually(() -> { + assertEquals(exchangeClient.getStatus().getPageBufferClientStatuses().get(0).getHttpRequestState(), "not scheduled", "httpRequestState"); + assertEquals(exchangeClient.getStatus().getPageBufferClientStatuses().get(1).getHttpRequestState(), "not scheduled", "httpRequestState"); + assertEquals(exchangeClient.getStatus().getPageBufferClientStatuses().get(2).getHttpRequestState(), "not scheduled", "httpRequestState"); + }); + + exchangeClient.close(); + } + + @Test + public void testTaskFailure() + throws Exception + { + DataSize maxResponseSize = DataSize.of(10, Unit.MEGABYTE); + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + + TaskId task1 = new TaskId(new StageId("query", 1), 0, 0); + TaskId task2 = new TaskId(new StageId("query", 1), 1, 0); + TaskId task3 = new TaskId(new StageId("query", 1), 2, 0); + TaskId task4 = new TaskId(new StageId("query", 1), 3, 0); + + URI location1 = URI.create("http://localhost:8080/1"); + URI location2 = URI.create("http://localhost:8080/2"); + URI location3 = URI.create("http://localhost:8080/3"); + URI location4 = URI.create("http://localhost:8080/4"); + + 
processor.addPage(location1, createSerializedPage("location-1-page-1")); + processor.addPage(location4, createSerializedPage("location-4-page-1")); + processor.addPage(location4, createSerializedPage("location-4-page-2")); + + TestingExchangeClientBuffer buffer = new TestingExchangeClientBuffer(DataSize.of(1, Unit.MEGABYTE)); + + Set failedTasks = newConcurrentHashSet(); + CountDownLatch latch = new CountDownLatch(2); + + @SuppressWarnings("resource") + ExchangeClient exchangeClient = new ExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + buffer, + maxResponseSize, + 1, + new Duration(1, SECONDS), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> { + failedTasks.add(taskId); + latch.countDown(); + }); + + assertThat(buffer.getAllTasks()).isEmpty(); + assertThat(buffer.getPages().asMap()).isEmpty(); + assertThat(buffer.getFinishedTasks()).isEmpty(); + assertThat(buffer.getFailedTasks().asMap()).isEmpty(); + assertFalse(buffer.isNoMoreTasks()); + + exchangeClient.addLocation(task1, location1); + assertThat(buffer.getAllTasks()).containsExactly(task1); + assertTaskIsNotFinished(buffer, task1); + + processor.setComplete(location1); + buffer.whenTaskFinished(task1).get(10, SECONDS); + assertThat(buffer.getPages().get(task1)).hasSize(1); + assertThat(buffer.getFinishedTasks()).containsExactly(task1); + + exchangeClient.addLocation(task2, location2); + assertThat(buffer.getAllTasks()).containsExactlyInAnyOrder(task1, task2); + assertTaskIsNotFinished(buffer, task2); + + RuntimeException randomException = new RuntimeException("randomfailure"); + processor.setFailed(location2, randomException); + buffer.whenTaskFailed(task2).get(10, SECONDS); + + assertThat(buffer.getFinishedTasks()).containsExactly(task1); + assertThat(buffer.getFailedTasks().keySet()).containsExactly(task2); + 
assertThat(buffer.getPages().get(task2)).hasSize(0); + + exchangeClient.addLocation(task3, location3); + assertThat(buffer.getAllTasks()).containsExactlyInAnyOrder(task1, task2, task3); + assertTaskIsNotFinished(buffer, task2); + assertTaskIsNotFinished(buffer, task3); + + TrinoException trinoException = new TrinoException(GENERIC_INTERNAL_ERROR, "generic internal error"); + processor.setFailed(location3, trinoException); + buffer.whenTaskFailed(task3).get(10, SECONDS); + + assertThat(buffer.getFinishedTasks()).containsExactly(task1); + assertThat(buffer.getFailedTasks().keySet()).containsExactlyInAnyOrder(task2, task3); + assertThat(buffer.getPages().get(task2)).hasSize(0); + assertThat(buffer.getPages().get(task3)).hasSize(0); + + assertTrue(latch.await(10, SECONDS)); + assertEquals(failedTasks, ImmutableSet.of(task2, task3)); + + exchangeClient.addLocation(task4, location4); + assertThat(buffer.getAllTasks()).containsExactlyInAnyOrder(task1, task2, task3, task4); + assertTaskIsNotFinished(buffer, task4); + + processor.setComplete(location4); + buffer.whenTaskFinished(task4).get(10, SECONDS); + assertThat(buffer.getPages().get(task4)).hasSize(2); + assertThat(buffer.getFinishedTasks()).containsExactlyInAnyOrder(task1, task4); + + assertFalse(exchangeClient.isFinished()); + buffer.setFinished(true); + assertTrue(exchangeClient.isFinished()); + + exchangeClient.close(); + + assertEventually(() -> assertEquals(exchangeClient.getStatus().getPageBufferClientStatuses().get(0).getHttpRequestState(), "not scheduled", "httpRequestState")); + assertEventually(() -> assertEquals(exchangeClient.getStatus().getPageBufferClientStatuses().get(1).getHttpRequestState(), "not scheduled", "httpRequestState")); + assertEventually(() -> assertEquals(exchangeClient.getStatus().getPageBufferClientStatuses().get(2).getHttpRequestState(), "not scheduled", "httpRequestState")); + assertEventually(() -> 
assertEquals(exchangeClient.getStatus().getPageBufferClientStatuses().get(3).getHttpRequestState(), "not scheduled", "httpRequestState")); + + assertThat(buffer.getFinishedTasks()).containsExactlyInAnyOrder(task1, task4); + assertThat(buffer.getFailedTasks().keySet()).containsExactlyInAnyOrder(task2, task3); + assertThat(buffer.getFailedTasks().asMap().get(task2)).hasSize(1); + assertThat(buffer.getFailedTasks().asMap().get(task2).iterator().next()).isInstanceOf(TrinoTransportException.class); + assertThat(buffer.getFailedTasks().asMap().get(task3)).hasSize(1); + assertThat(buffer.getFailedTasks().asMap().get(task3).iterator().next()).isEqualTo(trinoException); + + assertTrue(exchangeClient.isFinished()); + } + + private static void assertTaskIsNotFinished(TestingExchangeClientBuffer buffer, TaskId task) + { + assertThatThrownBy(() -> buffer.whenTaskFinished(task).get(50, MILLISECONDS)) + .isInstanceOf(TimeoutException.class); + } + + @Test + public void testStreamingBufferLimit() { DataSize maxResponseSize = DataSize.ofBytes(1); MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); @@ -224,7 +613,7 @@ public void testBufferLimit() ExchangeClient exchangeClient = new ExchangeClient( "localhost", DataIntegrityVerification.ABORT, - DataSize.ofBytes(1), + new StreamingExchangeClientBuffer(scheduler, DataSize.ofBytes(1)), maxResponseSize, 1, new Duration(1, TimeUnit.MINUTES), @@ -232,11 +621,12 @@ public void testBufferLimit() new TestingHttpClient(processor, newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-testBufferLimit-%s"))), scheduler, new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), - pageBufferClientCallbackExecutor); + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); - exchangeClient.addLocation(location); + exchangeClient.addLocation(new TaskId(new StageId("query", 1), 0, 0), location); exchangeClient.noMoreLocations(); - assertEquals(exchangeClient.isClosed(), 
false); + assertFalse(exchangeClient.isFinished()); long start = System.nanoTime(); @@ -286,40 +676,37 @@ public void testBufferLimit() assertNull(getNextPage(exchangeClient)); assertEquals(exchangeClient.getStatus().getBufferedPages(), 0); assertTrue(exchangeClient.getStatus().getBufferedBytes() == 0); - assertEquals(exchangeClient.isClosed(), true); + assertEquals(exchangeClient.isFinished(), true); + exchangeClient.close(); assertStatus(exchangeClient.getStatus().getPageBufferClientStatuses().get(0), location, "closed", 3, 5, 5, "not scheduled"); } @Test - public void testAbortOnDataCorruption() + public void testStreamingAbortOnDataCorruption() { URI location = URI.create("http://localhost:8080"); ExchangeClient exchangeClient = setUpDataCorruption(DataIntegrityVerification.ABORT, location); - assertFalse(exchangeClient.isClosed()); assertThatThrownBy(() -> getNextPage(exchangeClient)) .isInstanceOf(TrinoException.class) .hasMessageMatching("Checksum verification failure on localhost when reading from http://localhost:8080/0: Data corruption, read checksum: 0xf91cfe5d2bc6e1c2, calculated checksum: 0x3c51297c7b78052f"); - assertThatThrownBy(exchangeClient::isFinished) - .isInstanceOf(TrinoException.class) - .hasMessageMatching("Checksum verification failure on localhost when reading from http://localhost:8080/0: Data corruption, read checksum: 0xf91cfe5d2bc6e1c2, calculated checksum: 0x3c51297c7b78052f"); - exchangeClient.close(); } @Test - public void testRetryDataCorruption() + public void testStreamingRetryDataCorruption() { URI location = URI.create("http://localhost:8080"); ExchangeClient exchangeClient = setUpDataCorruption(DataIntegrityVerification.RETRY, location); - assertFalse(exchangeClient.isClosed()); + assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), createPage(1)); - assertFalse(exchangeClient.isClosed()); + assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), 
createPage(2)); assertNull(getNextPage(exchangeClient)); - assertTrue(exchangeClient.isClosed()); + assertTrue(exchangeClient.isFinished()); + exchangeClient.close(); ExchangeClientStatus status = exchangeClient.getStatus(); assertEquals(status.getBufferedPages(), 0); @@ -377,7 +764,7 @@ public synchronized Response handle(Request request) ExchangeClient exchangeClient = new ExchangeClient( "localhost", dataIntegrityVerification, - DataSize.of(32, Unit.MEGABYTE), + new StreamingExchangeClientBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), maxResponseSize, 1, new Duration(1, TimeUnit.MINUTES), @@ -385,16 +772,17 @@ public synchronized Response handle(Request request) new TestingHttpClient(processor, scheduler), scheduler, new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), - pageBufferClientCallbackExecutor); + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); - exchangeClient.addLocation(location); + exchangeClient.addLocation(new TaskId(new StageId("query", 1), 0, 0), location); exchangeClient.noMoreLocations(); return exchangeClient; } @Test - public void testClose() + public void testStreamingClose() throws Exception { DataSize maxResponseSize = DataSize.ofBytes(1); @@ -409,7 +797,7 @@ public void testClose() ExchangeClient exchangeClient = new ExchangeClient( "localhost", DataIntegrityVerification.ABORT, - DataSize.ofBytes(1), + new StreamingExchangeClientBuffer(scheduler, DataSize.ofBytes(1)), maxResponseSize, 1, new Duration(1, TimeUnit.MINUTES), @@ -417,12 +805,13 @@ public void testClose() new TestingHttpClient(processor, newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-testClose-%s"))), scheduler, new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), - pageBufferClientCallbackExecutor); - exchangeClient.addLocation(location); + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + exchangeClient.addLocation(new TaskId(new StageId("query", 1), 0, 0), location); 
exchangeClient.noMoreLocations(); // fetch a page - assertEquals(exchangeClient.isClosed(), false); + assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), createPage(1)); // close client while pages are still available @@ -430,12 +819,10 @@ public void testClose() while (!exchangeClient.isFinished()) { MILLISECONDS.sleep(10); } - assertEquals(exchangeClient.isClosed(), true); + assertTrue(exchangeClient.isFinished()); assertNull(exchangeClient.pollPage()); assertEquals(exchangeClient.getStatus().getBufferedPages(), 0); - assertEquals(exchangeClient.getStatus().getBufferedBytes(), 0); - // client should have sent only 2 requests: one to get all pages and once to get the done signal PageBufferClientStatus clientStatus = exchangeClient.getStatus().getPageBufferClientStatuses().get(0); assertEquals(clientStatus.getUri(), location); assertEquals(clientStatus.getState(), "closed", "status"); @@ -447,6 +834,11 @@ private static Page createPage(int size) return new Page(BlockAssertions.createLongSequenceBlock(0, size)); } + private static SerializedPage createSerializedPage(String value) + { + return new SerializedPage(utf8Slice(value), PageCodecMarker.MarkerSet.empty(), 1, value.length()); + } + private static SerializedPage getNextPage(ExchangeClient exchangeClient) { ListenableFuture futurePage = Futures.transform(exchangeClient.isBlocked(), ignored -> exchangeClient.pollPage(), directExecutor()); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java index 8008ebedf83c..19057f243494 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java @@ -23,6 +23,8 @@ import io.airlift.units.Duration; import io.trino.FeaturesConfig.DataIntegrityVerification; import io.trino.execution.Lifespan; +import io.trino.execution.StageId; 
+import io.trino.execution.TaskId; import io.trino.execution.buffer.PagesSerdeFactory; import io.trino.execution.buffer.TestingPagesSerdeFactory; import io.trino.metadata.Split; @@ -64,11 +66,11 @@ public class TestExchangeOperator private static final List TYPES = ImmutableList.of(VARCHAR); private static final PagesSerdeFactory SERDE_FACTORY = new TestingPagesSerdeFactory(); - private static final String TASK_1_ID = "task1"; - private static final String TASK_2_ID = "task2"; - private static final String TASK_3_ID = "task3"; + private static final TaskId TASK_1_ID = new TaskId(new StageId("query", 0), 0, 0); + private static final TaskId TASK_2_ID = new TaskId(new StageId("query", 0), 1, 0); + private static final TaskId TASK_3_ID = new TaskId(new StageId("query", 0), 2, 0); - private final LoadingCache taskBuffers = CacheBuilder.newBuilder().build(CacheLoader.from(TestingTaskBuffer::new)); + private final LoadingCache taskBuffers = CacheBuilder.newBuilder().build(CacheLoader.from(TestingTaskBuffer::new)); private ScheduledExecutorService scheduler; private ScheduledExecutorService scheduledExecutor; @@ -85,10 +87,10 @@ public void setUp() pageBufferClientCallbackExecutor = Executors.newSingleThreadExecutor(); httpClient = new TestingHttpClient(new TestingExchangeHttpClientHandler(taskBuffers), scheduler); - exchangeClientSupplier = (systemMemoryUsageListener) -> new ExchangeClient( + exchangeClientSupplier = (systemMemoryUsageListener, taskFailureListener, retryPolicy) -> new ExchangeClient( "localhost", DataIntegrityVerification.ABORT, - DataSize.of(32, MEGABYTE), + new StreamingExchangeClientBuffer(scheduler, DataSize.of(32, MEGABYTE)), DataSize.of(10, MEGABYTE), 3, new Duration(1, TimeUnit.MINUTES), @@ -96,7 +98,8 @@ public void setUp() httpClient, scheduler, systemMemoryUsageListener, - pageBufferClientCallbackExecutor); + pageBufferClientCallbackExecutor, + taskFailureListener); } @AfterClass(alwaysRun = true) @@ -144,9 +147,9 @@ public void testSimple() 
waitForFinished(operator); } - private static Split newRemoteSplit(String taskId) + private static Split newRemoteSplit(TaskId taskId) { - return new Split(REMOTE_CONNECTOR_ID, new RemoteSplit(URI.create("http://localhost/" + taskId)), Lifespan.taskWide()); + return new Split(REMOTE_CONNECTOR_ID, new RemoteSplit(taskId, URI.create("http://localhost/" + taskId)), Lifespan.taskWide()); } @Test @@ -251,7 +254,7 @@ public void testFinish() private SourceOperator createExchangeOperator() { - ExchangeOperatorFactory operatorFactory = new ExchangeOperatorFactory(0, new PlanNodeId("test"), exchangeClientSupplier, SERDE_FACTORY); + ExchangeOperatorFactory operatorFactory = new ExchangeOperatorFactory(0, new PlanNodeId("test"), exchangeClientSupplier, SERDE_FACTORY, RetryPolicy.NONE); DriverContext driverContext = createTaskContext(scheduler, scheduledExecutor, TEST_SESSION) .addPipelineContext(0, true, true, false) diff --git a/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java b/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java index 93ab3eab7005..ff0dcecba006 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java @@ -24,6 +24,8 @@ import io.airlift.units.DataSize.Unit; import io.airlift.units.Duration; import io.trino.FeaturesConfig.DataIntegrityVerification; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; import io.trino.execution.buffer.PagesSerde; import io.trino.execution.buffer.SerializedPage; import io.trino.operator.HttpPageBufferClient.ClientCallback; @@ -74,6 +76,7 @@ public class TestHttpPageBufferClient private ExecutorService pageBufferClientCallbackExecutor; private static final PagesSerde PAGES_SERDE = testingPagesSerde(); + private static final TaskId TASK_ID = new TaskId(new StageId("query", 0), 0, 0); @BeforeClass public void setUp() @@ -116,6 +119,7 @@ public 
void testHappyPath() expectedMaxSize, new Duration(1, TimeUnit.MINUTES), true, + TASK_ID, location, callback, scheduler, @@ -204,6 +208,7 @@ public void testLifecycle() DataSize.of(10, MEGABYTE), new Duration(1, TimeUnit.MINUTES), true, + TASK_ID, location, callback, scheduler, @@ -247,6 +252,7 @@ public void testInvalidResponses() DataSize.of(10, MEGABYTE), new Duration(1, TimeUnit.MINUTES), true, + TASK_ID, location, callback, scheduler, @@ -319,6 +325,7 @@ public void testCloseDuringPendingRequest() DataSize.of(10, MEGABYTE), new Duration(1, TimeUnit.MINUTES), true, + TASK_ID, location, callback, scheduler, @@ -376,6 +383,7 @@ public void testExceptionFromResponseHandler() DataSize.of(10, MEGABYTE), new Duration(30, TimeUnit.SECONDS), true, + TASK_ID, location, callback, scheduler, @@ -458,6 +466,7 @@ public boolean addPages(HttpPageBufferClient client, List pages) DataSize.of(10, MEGABYTE), new Duration(30, TimeUnit.SECONDS), true, + TASK_ID, location, callback, scheduler, diff --git a/core/trino-main/src/test/java/io/trino/operator/TestMergeOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestMergeOperator.java index 75e23ae6caff..d80e354aa7bb 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestMergeOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestMergeOperator.java @@ -22,6 +22,8 @@ import io.airlift.node.NodeInfo; import io.trino.FeaturesConfig; import io.trino.execution.Lifespan; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; import io.trino.execution.buffer.PagesSerdeFactory; import io.trino.execution.buffer.TestingPagesSerdeFactory; import io.trino.metadata.Split; @@ -63,9 +65,9 @@ @Test(singleThreaded = true) public class TestMergeOperator { - private static final String TASK_1_ID = "task1"; - private static final String TASK_2_ID = "task2"; - private static final String TASK_3_ID = "task3"; + private static final TaskId TASK_1_ID = new TaskId(new StageId("query", 0), 0, 0); + 
private static final TaskId TASK_2_ID = new TaskId(new StageId("query", 0), 1, 0); + private static final TaskId TASK_3_ID = new TaskId(new StageId("query", 0), 2, 0); private final AtomicInteger operatorId = new AtomicInteger(); @@ -75,7 +77,7 @@ public class TestMergeOperator private ExchangeClientFactory exchangeClientFactory; private OrderingCompiler orderingCompiler; - private LoadingCache taskBuffers; + private LoadingCache taskBuffers; @BeforeMethod public void setUp() @@ -351,9 +353,9 @@ private MergeOperator createMergeOperator(List sourceTypes, List return (MergeOperator) factory.createOperator(driverContext); } - private static Split createRemoteSplit(String taskId) + private static Split createRemoteSplit(TaskId taskId) { - return new Split(ExchangeOperator.REMOTE_CONNECTOR_ID, new RemoteSplit(URI.create("http://localhost/" + taskId)), Lifespan.taskWide()); + return new Split(ExchangeOperator.REMOTE_CONNECTOR_ID, new RemoteSplit(taskId, URI.create("http://localhost/" + taskId)), Lifespan.taskWide()); } private static List pullAvailablePages(Operator operator) diff --git a/core/trino-main/src/test/java/io/trino/operator/TestStreamingExchangeClientBuffer.java b/core/trino-main/src/test/java/io/trino/operator/TestStreamingExchangeClientBuffer.java new file mode 100644 index 000000000000..5b139a3e8ae0 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/TestStreamingExchangeClientBuffer.java @@ -0,0 +1,212 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.units.DataSize; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; +import io.trino.execution.buffer.PageCodecMarker; +import io.trino.execution.buffer.SerializedPage; +import io.trino.spi.QueryId; +import org.testng.annotations.Test; + +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.units.DataSize.Unit.KILOBYTE; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + +public class TestStreamingExchangeClientBuffer +{ + private static final StageId STAGE_ID = new StageId(new QueryId("query"), 0); + private static final TaskId TASK_0 = new TaskId(STAGE_ID, 0, 0); + private static final TaskId TASK_1 = new TaskId(STAGE_ID, 1, 0); + private static final SerializedPage PAGE_0 = createPage("page0"); + private static final SerializedPage PAGE_1 = createPage("page-1"); + private static final SerializedPage PAGE_2 = createPage("page-_2"); + + @Test + public void testHappyPath() + { + try (StreamingExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + assertNull(buffer.pollPage()); + + buffer.addTask(TASK_0); + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + assertNull(buffer.pollPage()); + + buffer.addTask(TASK_1); + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + 
assertNull(buffer.pollPage()); + + buffer.noMoreTasks(); + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + assertNull(buffer.pollPage()); + + buffer.addPages(TASK_0, ImmutableList.of(PAGE_0)); + assertEquals(buffer.getBufferedPageCount(), 1); + assertEquals(buffer.getRetainedSizeInBytes(), PAGE_0.getRetainedSizeInBytes()); + assertEquals(buffer.getMaxRetainedSizeInBytes(), PAGE_0.getRetainedSizeInBytes()); + assertEquals(buffer.getRemainingCapacityInBytes(), DataSize.of(1, KILOBYTE).toBytes() - PAGE_0.getRetainedSizeInBytes()); + assertFalse(buffer.isFinished()); + assertTrue(buffer.isBlocked().isDone()); + assertPageEquals(buffer.pollPage(), PAGE_0); + assertEquals(buffer.getRetainedSizeInBytes(), 0); + assertEquals(buffer.getMaxRetainedSizeInBytes(), PAGE_0.getRetainedSizeInBytes()); + assertEquals(buffer.getRemainingCapacityInBytes(), DataSize.of(1, KILOBYTE).toBytes()); + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + + buffer.taskFinished(TASK_0); + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + + buffer.addPages(TASK_1, ImmutableList.of(PAGE_1, PAGE_2)); + assertEquals(buffer.getBufferedPageCount(), 2); + assertEquals(buffer.getRetainedSizeInBytes(), PAGE_1.getRetainedSizeInBytes() + PAGE_2.getRetainedSizeInBytes()); + assertEquals(buffer.getMaxRetainedSizeInBytes(), PAGE_1.getRetainedSizeInBytes() + PAGE_2.getRetainedSizeInBytes()); + assertEquals(buffer.getRemainingCapacityInBytes(), DataSize.of(1, KILOBYTE).toBytes() - PAGE_1.getRetainedSizeInBytes() - PAGE_2.getRetainedSizeInBytes()); + assertFalse(buffer.isFinished()); + assertTrue(buffer.isBlocked().isDone()); + assertPageEquals(buffer.pollPage(), PAGE_1); + assertPageEquals(buffer.pollPage(), PAGE_2); + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + assertEquals(buffer.getRetainedSizeInBytes(), 0); + assertEquals(buffer.getMaxRetainedSizeInBytes(), 
PAGE_1.getRetainedSizeInBytes() + PAGE_2.getRetainedSizeInBytes()); + assertEquals(buffer.getRemainingCapacityInBytes(), DataSize.of(1, KILOBYTE).toBytes()); + + buffer.taskFinished(TASK_1); + assertTrue(buffer.isFinished()); + assertTrue(buffer.isBlocked().isDone()); + } + } + + @Test + public void testClose() + { + StreamingExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE)); + buffer.addTask(TASK_0); + buffer.addTask(TASK_1); + + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + assertNull(buffer.pollPage()); + + buffer.close(); + + assertTrue(buffer.isFinished()); + assertTrue(buffer.isBlocked().isDone()); + assertNull(buffer.pollPage()); + } + + @Test + public void testIsFinished() + { + // 0 tasks + try (StreamingExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + + buffer.noMoreTasks(); + + assertTrue(buffer.isFinished()); + assertTrue(buffer.isBlocked().isDone()); + } + + // single task + try (StreamingExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + + buffer.addTask(TASK_0); + buffer.noMoreTasks(); + + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + + buffer.taskFinished(TASK_0); + + assertTrue(buffer.isFinished()); + assertTrue(buffer.isBlocked().isDone()); + } + + // single failed task + try (StreamingExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + + buffer.addTask(TASK_0); + + assertFalse(buffer.isFinished()); + assertFalse(buffer.isBlocked().isDone()); + + RuntimeException error = new RuntimeException(); + 
buffer.taskFailed(TASK_0, error); + + assertTrue(buffer.isFinished()); + assertTrue(buffer.isBlocked().isDone()); + assertThatThrownBy(buffer::pollPage).isEqualTo(error); + } + } + + @Test + public void testFutureCancellationDoesNotAffectOtherFutures() + { + try (StreamingExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { + assertFalse(buffer.isFinished()); + + ListenableFuture blocked1 = buffer.isBlocked(); + ListenableFuture blocked2 = buffer.isBlocked(); + ListenableFuture blocked3 = buffer.isBlocked(); + + assertFalse(blocked1.isDone()); + assertFalse(blocked2.isDone()); + assertFalse(blocked3.isDone()); + + blocked3.cancel(true); + assertFalse(blocked1.isDone()); + assertFalse(blocked2.isDone()); + + buffer.noMoreTasks(); + + assertTrue(buffer.isFinished()); + assertTrue(blocked1.isDone()); + assertTrue(blocked2.isDone()); + } + } + + private static SerializedPage createPage(String value) + { + return new SerializedPage(utf8Slice(value), PageCodecMarker.MarkerSet.empty(), 1, value.length()); + } + + private static void assertPageEquals(SerializedPage actual, SerializedPage expected) + { + assertEquals(actual.getPositionCount(), expected.getPositionCount()); + assertEquals(actual.getUncompressedSizeInBytes(), expected.getUncompressedSizeInBytes()); + assertEquals(actual.getPageCodecMarkers(), expected.getPageCodecMarkers()); + assertEquals(actual.getSlice(), expected.getSlice()); + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/TestingExchangeClientBuffer.java b/core/trino-main/src/test/java/io/trino/operator/TestingExchangeClientBuffer.java new file mode 100644 index 000000000000..c26891bc9722 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/TestingExchangeClientBuffer.java @@ -0,0 +1,184 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ListMultimap; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.airlift.units.DataSize; +import io.trino.execution.TaskId; +import io.trino.execution.buffer.SerializedPage; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.Futures.immediateVoidFuture; +import static java.util.Objects.requireNonNull; + +public class TestingExchangeClientBuffer + implements ExchangeClientBuffer +{ + private ListenableFuture blocked = immediateVoidFuture(); + private final Set allTasks = new HashSet<>(); + private final ListMultimap pages = ArrayListMultimap.create(); + private final Set finishedTasks = new HashSet<>(); + private final ListMultimap failedTasks = ArrayListMultimap.create(); + private boolean noMoreTasks; + private boolean finished; + private final long remainingBufferCapacityInBytes; + + private final Map> taskFinished = new HashMap<>(); + private final Map> taskFailed = new HashMap<>(); + + public TestingExchangeClientBuffer(DataSize bufferCapacity) + { + this.remainingBufferCapacityInBytes = bufferCapacity.toBytes(); + } + + @Override + public synchronized 
ListenableFuture isBlocked() + { + return blocked; + } + + public synchronized void setBlocked(ListenableFuture blocked) + { + this.blocked = requireNonNull(blocked, "blocked is null"); + } + + @Override + public synchronized SerializedPage pollPage() + { + return null; + } + + @Override + public synchronized void addTask(TaskId taskId) + { + checkState(allTasks.add(taskId), "task is already present: %s", taskId); + } + + public synchronized Set getAllTasks() + { + return ImmutableSet.copyOf(allTasks); + } + + @Override + public synchronized void addPages(TaskId taskId, List pages) + { + checkState(allTasks.contains(taskId), "task is expected to be present: %s", taskId); + this.pages.putAll(taskId, pages); + } + + public synchronized ListMultimap getPages() + { + return ImmutableListMultimap.copyOf(pages); + } + + @Override + public synchronized void taskFinished(TaskId taskId) + { + checkState(allTasks.contains(taskId), "task is expected to be present: %s", taskId); + checkState(finishedTasks.add(taskId), "task is already finished: %s", taskId); + taskFinished.computeIfAbsent(taskId, key -> SettableFuture.create()).set(null); + } + + public synchronized Set getFinishedTasks() + { + return ImmutableSet.copyOf(finishedTasks); + } + + public synchronized ListenableFuture whenTaskFinished(TaskId taskId) + { + return taskFinished.computeIfAbsent(taskId, key -> SettableFuture.create()); + } + + @Override + public synchronized void taskFailed(TaskId taskId, Throwable t) + { + checkState(allTasks.contains(taskId), "task is expected to be present: %s", taskId); + checkState(!finishedTasks.contains(taskId), "task is already finished: %s", taskId); + failedTasks.put(taskId, t); + taskFailed.computeIfAbsent(taskId, key -> SettableFuture.create()).set(null); + } + + public synchronized ListMultimap getFailedTasks() + { + return ImmutableListMultimap.copyOf(failedTasks); + } + + public synchronized ListenableFuture whenTaskFailed(TaskId taskId) + { + return 
taskFailed.computeIfAbsent(taskId, key -> SettableFuture.create()); + } + + @Override + public synchronized void noMoreTasks() + { + noMoreTasks = true; + } + + public synchronized boolean isNoMoreTasks() + { + return noMoreTasks; + } + + @Override + public synchronized boolean isFinished() + { + return finished; + } + + public synchronized void setFinished(boolean finished) + { + this.finished = finished; + } + + @Override + public synchronized long getRemainingCapacityInBytes() + { + return remainingBufferCapacityInBytes; + } + + @Override + public synchronized long getRetainedSizeInBytes() + { + return 0; + } + + @Override + public synchronized long getMaxRetainedSizeInBytes() + { + return 0; + } + + @Override + public synchronized int getBufferedPageCount() + { + return 0; + } + + @Override + public synchronized void close() + { + finished = true; + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/TestingExchangeHttpClientHandler.java b/core/trino-main/src/test/java/io/trino/operator/TestingExchangeHttpClientHandler.java index 2dab0b73dad0..8d3a8605f856 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestingExchangeHttpClientHandler.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestingExchangeHttpClientHandler.java @@ -23,6 +23,7 @@ import io.airlift.http.client.testing.TestingHttpClient; import io.airlift.http.client.testing.TestingResponse; import io.airlift.slice.DynamicSliceOutput; +import io.trino.execution.TaskId; import io.trino.execution.buffer.PagesSerde; import io.trino.execution.buffer.SerializedPage; import io.trino.spi.Page; @@ -34,6 +35,7 @@ import static io.trino.server.InternalHeaders.TRINO_BUFFER_COMPLETE; import static io.trino.server.InternalHeaders.TRINO_PAGE_NEXT_TOKEN; import static io.trino.server.InternalHeaders.TRINO_PAGE_TOKEN; +import static io.trino.server.InternalHeaders.TRINO_TASK_FAILED; import static io.trino.server.InternalHeaders.TRINO_TASK_INSTANCE_ID; import static 
io.trino.server.PagesResponseWriter.SERIALIZED_PAGES_MAGIC; import static java.util.Objects.requireNonNull; @@ -45,9 +47,9 @@ public class TestingExchangeHttpClientHandler { private static final PagesSerde PAGES_SERDE = testingPagesSerde(); - private final LoadingCache taskBuffers; + private final LoadingCache taskBuffers; - public TestingExchangeHttpClientHandler(LoadingCache taskBuffers) + public TestingExchangeHttpClientHandler(LoadingCache taskBuffers) { this.taskBuffers = requireNonNull(taskBuffers, "taskBuffers is null"); } @@ -62,12 +64,13 @@ public Response handle(Request request) } assertEquals(parts.size(), 2); - String taskId = parts.get(0); + TaskId taskId = TaskId.valueOf(parts.get(0)); int pageToken = Integer.parseInt(parts.get(1)); ImmutableListMultimap.Builder headers = ImmutableListMultimap.builder(); headers.put(TRINO_TASK_INSTANCE_ID, "task-instance-id"); headers.put(TRINO_PAGE_TOKEN, String.valueOf(pageToken)); + headers.put(TRINO_TASK_FAILED, "false"); TestingTaskBuffer taskBuffer = taskBuffers.getUnchecked(taskId); Page page = taskBuffer.getPage(pageToken); diff --git a/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java b/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java index 7be4927314b5..cc8865aad78b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java @@ -24,6 +24,7 @@ import io.trino.RowPagesBuilder; import io.trino.execution.Lifespan; import io.trino.execution.NodeTaskMap; +import io.trino.execution.StageId; import io.trino.execution.TaskId; import io.trino.execution.TaskStateMachine; import io.trino.execution.scheduler.NodeScheduler; @@ -480,7 +481,7 @@ public void testInnerJoinWithFailingSpill(boolean probeHashEnabled, List whenSpill, SingleStreamSpillerFactory buildSpillerFactory, PartitioningSpillerFactory joinSpillerFactory) throws Exception { - 
TaskStateMachine taskStateMachine = new TaskStateMachine(new TaskId("query", 0, 0), executor); + TaskStateMachine taskStateMachine = new TaskStateMachine(new TaskId(new StageId("query", 0), 0, 0), executor); TaskContext taskContext = TestingTaskContext.createTaskContext(executor, scheduledExecutor, TEST_SESSION, taskStateMachine); DriverContext joinDriverContext = taskContext.addPipelineContext(2, true, true, false).addDriverContext(); @@ -657,7 +658,7 @@ private static MaterializedResult getProperColumns(Operator joinOperator, List summary = dynamicFilterService.getSummary(queryId, filterId); assertTrue(summary.isPresent()); @@ -201,7 +202,7 @@ symbol3, new TestingColumnHandle("probeColumnB")), assertFalse(blockedFuture.isDone()); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId1, 0), + new TaskId(stageId1, 0, 0), ImmutableMap.of(filterId1, singleValue(INTEGER, 1L))); // tuple domain from two tasks are needed for dynamic filter to be narrowed down @@ -211,7 +212,7 @@ symbol3, new TestingColumnHandle("probeColumnB")), assertFalse(blockedFuture.isDone()); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId1, 1), + new TaskId(stageId1, 1, 0), ImmutableMap.of(filterId1, singleValue(INTEGER, 2L))); // dynamic filter (id1) has been collected as tuple domains from two tasks have been provided @@ -231,7 +232,7 @@ symbol3, new TestingColumnHandle("probeColumnB")), assertFalse(blockedFuture.isDone()); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId2, 0), + new TaskId(stageId2, 0, 0), ImmutableMap.of(filterId2, singleValue(INTEGER, 2L))); // tuple domain from two tasks (stage 2) are needed for dynamic filter to be narrowed down @@ -246,7 +247,7 @@ symbol3, new TestingColumnHandle("probeColumnB")), assertEquals(stats.getDynamicFiltersCompleted(), 1); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId2, 1), + new TaskId(stageId2, 1, 0), ImmutableMap.of(filterId2, singleValue(INTEGER, 3L))); // dynamic filter (id2) has 
been collected as tuple domains from two tasks have been provided @@ -285,7 +286,7 @@ symbol2, new TestingColumnHandle("probeColumnA")), singleValue(INTEGER, 2L)))); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId3, 0), + new TaskId(stageId3, 0, 0), ImmutableMap.of(filterId3, none(INTEGER))); // tuple domain from two tasks (stage 3) are needed for dynamic filter to be narrowed down @@ -300,7 +301,7 @@ symbol2, new TestingColumnHandle("probeColumnA")), assertEquals(stats.getDynamicFiltersCompleted(), 2); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId3, 1), + new TaskId(stageId3, 1, 0), ImmutableMap.of(filterId3, none(INTEGER))); // "none" dynamic filter (id3) has been collected for column B as tuple domains from two tasks have been provided @@ -357,7 +358,7 @@ symbol1, new TestingColumnHandle("probeColumnA")), assertFalse(dynamicFilter.isBlocked().isDone()); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId1, 1), + new TaskId(stageId1, 1, 0), ImmutableMap.of(filterId1, Domain.all(INTEGER))); // dynamic filter should be unblocked and completed @@ -397,7 +398,7 @@ public void testDynamicFilterCoercion() assertTrue(dynamicFilter.getCurrentPredicate().isAll()); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId1, 0), + new TaskId(stageId1, 0, 0), ImmutableMap.of(filterId1, multipleValues(BIGINT, ImmutableList.of(1L, 2L, 3L)))); assertTrue(dynamicFilter.isComplete()); assertEquals(dynamicFilter.getCurrentPredicate(), TupleDomain.withColumnDomains(ImmutableMap.of( @@ -446,7 +447,7 @@ symbol1, new TestingColumnHandle("probeColumnA")), assertTrue(dynamicFilter.isBlocked().isDone()); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId1, 0), + new TaskId(stageId1, 0, 0), ImmutableMap.of(filterId1, singleValue(INTEGER, 1L))); // tuple domain from single broadcast join task is sufficient @@ -499,7 +500,7 @@ public void testStageCannotScheduleMoreTasks() // adding task dynamic filters shouldn't complete 
dynamic filter dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId1, 0), + new TaskId(stageId1, 0, 0), ImmutableMap.of(filterId1, singleValue(INTEGER, 1L))); assertTrue(dynamicFilter.getCurrentPredicate().isAll()); @@ -542,7 +543,7 @@ public void testDynamicFilterCancellation() assertEquals(dynamicFilter.getCurrentPredicate(), TupleDomain.all()); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId, 0), + new TaskId(stageId, 0, 0), ImmutableMap.of(filterId, singleValue(INTEGER, 1L))); assertEquals(dynamicFilter.getCurrentPredicate(), TupleDomain.all()); @@ -554,7 +555,7 @@ public void testDynamicFilterCancellation() assertFalse(dynamicFilter.isComplete()); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId, 1), + new TaskId(stageId, 1, 0), ImmutableMap.of(filterId, singleValue(INTEGER, 2L))); assertTrue(isBlocked.isDone()); assertTrue(dynamicFilter.isComplete()); @@ -635,7 +636,7 @@ public void testMultipleColumnMapping() Domain domain = singleValue(INTEGER, 1L); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId1, 0), + new TaskId(stageId1, 0, 0), ImmutableMap.of(filterId1, domain)); assertEquals( @@ -666,13 +667,13 @@ public void testDynamicFilterConsumer() assertTrue(consumerCollectedFilters.isEmpty()); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId, 0), + new TaskId(stageId, 0, 0), ImmutableMap.of(filterId1, singleValue(INTEGER, 1L))); assertTrue(consumerCollectedFilters.isEmpty()); // complete only filterId1 dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId, 1), + new TaskId(stageId, 1, 0), ImmutableMap.of( filterId1, singleValue(INTEGER, 3L), filterId2, singleValue(INTEGER, 2L))); @@ -688,7 +689,7 @@ filterId1, singleValue(INTEGER, 3L), // complete filterId2 dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId, 0), + new TaskId(stageId, 0, 0), ImmutableMap.of(filterId2, singleValue(INTEGER, 4L))); assertEquals( consumerCollectedFilters, @@ -725,7 +726,7 @@ public 
void testDynamicFilterConsumerCallbackCount() assertTrue(consumerCollectedFilters.isEmpty()); dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId, 0), + new TaskId(stageId, 0, 0), ImmutableMap.of( filterId1, singleValue(INTEGER, 1L), filterId2, singleValue(INTEGER, 2L))); @@ -733,7 +734,7 @@ filterId1, singleValue(INTEGER, 1L), // complete both filterId1 and filterId2 dynamicFilterService.addTaskDynamicFilters( - new TaskId(stageId, 1), + new TaskId(stageId, 1, 0), ImmutableMap.of( filterId1, singleValue(INTEGER, 3L), filterId2, singleValue(INTEGER, 4L))); @@ -813,7 +814,7 @@ private static PlanFragment createPlan( tableScan, createDynamicFilterExpression(session, createTestMetadataManager(), consumedDynamicFilterId, VARCHAR, symbol.toSymbolReference())); - RemoteSourceNode remote = new RemoteSourceNode(new PlanNodeId("remote_id"), new PlanFragmentId("plan_fragment_id"), ImmutableList.of(buildSymbol), Optional.empty(), exchangeType); + RemoteSourceNode remote = new RemoteSourceNode(new PlanNodeId("remote_id"), new PlanFragmentId("plan_fragment_id"), ImmutableList.of(buildSymbol), Optional.empty(), exchangeType, RetryPolicy.NONE); return new PlanFragment( new PlanFragmentId("plan_id"), new JoinNode(new PlanNodeId("join_id"), diff --git a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java index a259b8bec336..03f284d01c1d 100644 --- a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java +++ b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java @@ -312,7 +312,7 @@ public void testOutboundDynamicFilters() // make sure initial dynamic filter is collected CompletableFuture future = dynamicFilter.isBlocked(); dynamicFilterService.addTaskDynamicFilters( - new TaskId(queryId.getId(), 1, 1), + new TaskId(new StageId(queryId.getId(), 1), 1, 0), ImmutableMap.of(filterId1, Domain.singleValue(BIGINT, 
1L))); future.get(); assertEquals( @@ -342,7 +342,7 @@ public void testOutboundDynamicFilters() future = dynamicFilter.isBlocked(); dynamicFilterService.addTaskDynamicFilters( - new TaskId(queryId.getId(), 1, 1), + new TaskId(new StageId(queryId.getId(), 1), 1, 0), ImmutableMap.of(filterId2, Domain.singleValue(BIGINT, 2L))); future.get(); assertEquals( @@ -412,7 +412,7 @@ private RemoteTask createRemoteTask(HttpRemoteTaskFactory httpRemoteTaskFactory, { return httpRemoteTaskFactory.createRemoteTask( TEST_SESSION, - new TaskId("test", 1, 2), + new TaskId(new StageId("test", 1), 2, 0), new InternalNode("node-id", URI.create("http://fake.invalid/"), new NodeVersion("version"), false), TaskTestUtils.PLAN_FRAGMENT, ImmutableMultimap.of(), diff --git a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java index 5449565c9e60..1709da1a6457 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java @@ -20,6 +20,7 @@ import io.trino.FeaturesConfig.DataIntegrityVerification; import io.trino.FeaturesConfig.JoinDistributionType; import io.trino.FeaturesConfig.JoinReorderingStrategy; +import io.trino.operator.RetryPolicy; import org.testng.annotations.Test; import java.util.Map; @@ -34,6 +35,7 @@ import static io.trino.FeaturesConfig.JoinReorderingStrategy.NONE; import static io.trino.sql.analyzer.RegexLibrary.JONI; import static io.trino.sql.analyzer.RegexLibrary.RE2J; +import static java.util.concurrent.TimeUnit.HOURS; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; @@ -115,7 +117,11 @@ public void testDefaults() .setMergeProjectWithValues(true) .setLegacyCatalogRoles(false) .setDisableSetPropertiesSecurityCheckForCreateDdl(false) - .setIncrementalHashArrayLoadFactorEnabled(true)); + 
.setIncrementalHashArrayLoadFactorEnabled(true) + .setRetryPolicy(RetryPolicy.NONE) + .setRetryAttempts(4) + .setRetryInitialDelay(new Duration(10, SECONDS)) + .setRetryMaxDelay(new Duration(1, MINUTES))); } @Test @@ -195,6 +201,10 @@ public void testExplicitPropertyMappings() .put("deprecated.legacy-catalog-roles", "true") .put("deprecated.disable-set-properties-security-check-for-create-ddl", "true") .put("incremental-hash-array-load-factor.enabled", "false") + .put("retry-policy", "QUERY") + .put("retry-attempts", "0") + .put("retry-initial-delay", "1m") + .put("retry-max-delay", "1h") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -270,7 +280,11 @@ public void testExplicitPropertyMappings() .setMergeProjectWithValues(false) .setLegacyCatalogRoles(true) .setDisableSetPropertiesSecurityCheckForCreateDdl(true) - .setIncrementalHashArrayLoadFactorEnabled(false); + .setIncrementalHashArrayLoadFactorEnabled(false) + .setRetryPolicy(RetryPolicy.QUERY) + .setRetryAttempts(0) + .setRetryInitialDelay(new Duration(1, MINUTES)) + .setRetryMaxDelay(new Duration(1, HOURS)); assertFullMapping(properties, expected); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestShowQueries.java b/core/trino-main/src/test/java/io/trino/sql/query/TestShowQueries.java index e6a6c2251caf..ee3714e5ab9f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestShowQueries.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestShowQueries.java @@ -114,7 +114,7 @@ public void testShowSessionLike() { assertThat(assertions.query( "SHOW SESSION LIKE '%page_row_c%'")) - .matches("VALUES (cast('filter_and_project_min_output_page_row_count' as VARCHAR(53)), cast('256' as VARCHAR(14)), cast('256' as VARCHAR(14)), 'integer', cast('Experimental: Minimum output page row count for filter and project operators' as VARCHAR(103)))"); + .matches("VALUES (cast('filter_and_project_min_output_page_row_count' as VARCHAR(53)), cast('256' as VARCHAR(14)), cast('256' as 
VARCHAR(14)), 'integer', cast('Experimental: Minimum output page row count for filter and project operators' as VARCHAR(142)))"); } @Test @@ -126,7 +126,7 @@ public void testShowSessionLikeWithEscape() .hasMessage("Escape string must be a single character"); assertThat(assertions.query( "SHOW SESSION LIKE '%page$_row$_c%' ESCAPE '$'")) - .matches("VALUES (cast('filter_and_project_min_output_page_row_count' as VARCHAR(53)), cast('256' as VARCHAR(14)), cast('256' as VARCHAR(14)), 'integer', cast('Experimental: Minimum output page row count for filter and project operators' as VARCHAR(103)))"); + .matches("VALUES (cast('filter_and_project_min_output_page_row_count' as VARCHAR(53)), cast('256' as VARCHAR(14)), cast('256' as VARCHAR(14)), 'integer', cast('Experimental: Minimum output page row count for filter and project operators' as VARCHAR(142)))"); } @Test diff --git a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java index 6ab7979d1671..846b7ca7c247 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +++ b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java @@ -156,6 +156,7 @@ public enum StandardErrorCode CONFIGURATION_UNAVAILABLE(65560, INTERNAL_ERROR), INVALID_RESOURCE_GROUP(65561, INTERNAL_ERROR), SERIALIZATION_ERROR(65562, INTERNAL_ERROR), + REMOTE_TASK_FAILED(65563, INTERNAL_ERROR), GENERIC_INSUFFICIENT_RESOURCES(131072, INSUFFICIENT_RESOURCES), EXCEEDED_GLOBAL_MEMORY_LIMIT(131073, INSUFFICIENT_RESOURCES), diff --git a/plugin/trino-hive/pom.xml b/plugin/trino-hive/pom.xml index c7e900bf7a2f..1623967f0086 100644 --- a/plugin/trino-hive/pom.xml +++ b/plugin/trino-hive/pom.xml @@ -411,6 +411,7 @@ **/TestHiveGlueMetastore.java **/TestFullParquetReader.java + **/TestHiveFailureRecovery.java @@ -451,5 +452,22 @@ + + + test-failure-recovery + + + + org.apache.maven.plugins + maven-surefire-plugin + + + **/TestHiveFailureRecovery.java + + + + + + 
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFailureRecovery.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFailureRecovery.java new file mode 100644 index 000000000000..ce2165575140 --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFailureRecovery.java @@ -0,0 +1,143 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import io.trino.Session; +import io.trino.testing.AbstractTestFailureRecovery; +import io.trino.testing.QueryRunner; +import io.trino.tpch.TpchTable; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestHiveFailureRecovery + extends AbstractTestFailureRecovery +{ + @Override + protected QueryRunner createQueryRunner(List> requiredTpchTables, Map configProperties, Map coordinatorProperties) + throws Exception + { + return HiveQueryRunner.builder() + .setInitialTables(requiredTpchTables) + .setCoordinatorProperties(coordinatorProperties) + .setExtraProperties(configProperties) + .build(); + } + + @Override + // create table is not atomic at the moment + @Test(enabled = false) + public void testCreateTable() + { + super.testCreateTable(); + } + + @Override + // delete is unsupported for non ACID tables + public void testDelete() + { + assertThatThrownBy(super::testDelete) + 
.hasMessageContaining("Deletes must match whole partitions for non-transactional tables"); + } + + @Override + // delete is unsupported for non ACID tables + public void testDeleteWithSubquery() + { + assertThatThrownBy(super::testDeleteWithSubquery) + .hasMessageContaining("Deletes must match whole partitions for non-transactional tables"); + } + + @Override + // update is unsupported for non ACID tables + public void testUpdate() + { + assertThatThrownBy(super::testUpdate) + .hasMessageContaining("Hive update is only supported for ACID transactional tables"); + } + + @Override + // update is unsupported for non ACID tables + public void testUpdateWithSubquery() + { + assertThatThrownBy(super::testUpdateWithSubquery) + .hasMessageContaining("Hive update is only supported for ACID transactional tables"); + } + + @Override + // materialized views are currently not implemented by Hive connector + public void testRefreshMaterializedView() + { + assertThatThrownBy(super::testRefreshMaterializedView) + .hasMessageContaining("This connector does not support creating materialized views"); + } + + @Test(invocationCount = INVOCATION_COUNT, enabled = false) + // create table is not atomic at the moment + public void testCreatePartitionedTable() + { + testTableModification( + Optional.empty(), + "CREATE TABLE WITH (partitioned_by = ARRAY['p']) AS SELECT *, 'partition1' p FROM orders", + Optional.of("DROP TABLE
")); + } + + @Test(invocationCount = INVOCATION_COUNT, enabled = false) + // create partition is not atomic at the moment + public void testInsertIntoNewPartition() + { + testTableModification( + Optional.of("CREATE TABLE
WITH (partitioned_by = ARRAY['p']) AS SELECT *, 'partition1' p FROM orders"), + "INSERT INTO
SELECT *, 'partition2' p FROM orders", + Optional.of("DROP TABLE
")); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testInsertIntoExistingPartition() + { + testTableModification( + Optional.of("CREATE TABLE
WITH (partitioned_by = ARRAY['p']) AS SELECT *, 'partition1' p FROM orders"), + "INSERT INTO
SELECT *, 'partition1' p FROM orders", + Optional.of("DROP TABLE
")); + } + + @Test(invocationCount = INVOCATION_COUNT, enabled = false) + // replace partition is not atomic at the moment + public void testReplaceExistingPartition() + { + testTableModification( + Optional.of(Session.builder(getQueryRunner().getDefaultSession()) + .setCatalogSessionProperty("hive", "insert_existing_partitions_behavior", "OVERWRITE") + .build()), + Optional.of("CREATE TABLE
WITH (partitioned_by = ARRAY['p']) AS SELECT *, 'partition1' p FROM orders"), + "INSERT INTO
SELECT *, 'partition1' p FROM orders", + Optional.of("DROP TABLE
")); + } + + @Test(invocationCount = INVOCATION_COUNT) + // delete is unsupported for non ACID tables + public void testDeletePartitionWithSubquery() + { + assertThatThrownBy(() -> { + testTableModification( + Optional.of("CREATE TABLE
WITH (partitioned_by = ARRAY['p']) AS SELECT *, 0 p FROM orders"), + "DELETE FROM
WHERE p = (SELECT min(nationkey) FROM nation)", + Optional.of("DROP TABLE
")); + }).hasMessageContaining("Deletes must match whole partitions for non-transactional tables"); + } +} diff --git a/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/ThriftQueryRunner.java b/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/ThriftQueryRunner.java index 87165830347b..19d58053970b 100644 --- a/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/ThriftQueryRunner.java +++ b/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/ThriftQueryRunner.java @@ -28,6 +28,7 @@ import io.airlift.log.Logging; import io.trino.Session; import io.trino.cost.StatsCalculator; +import io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.SqlFunction; @@ -35,6 +36,7 @@ import io.trino.plugin.thrift.server.ThriftIndexedTpchService; import io.trino.plugin.thrift.server.ThriftTpchService; import io.trino.server.testing.TestingTrinoServer; +import io.trino.spi.ErrorType; import io.trino.spi.Plugin; import io.trino.split.PageSourceManager; import io.trino.split.SplitManager; @@ -51,6 +53,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.locks.Lock; import static io.airlift.testing.Closeables.closeAllSuppress; @@ -300,5 +303,17 @@ public Lock getExclusiveLock() { return source.getExclusiveLock(); } + + @Override + public void injectTaskFailure( + String traceToken, + int stageId, + int partitionId, + int attemptId, + InjectedFailureType injectionType, + Optional errorType) + { + source.injectTaskFailure(traceToken, stageId, partitionId, attemptId, injectionType, errorType); + } } } diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java index 61ad4e0f679f..795f137431e4 100644 --- 
a/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java @@ -21,6 +21,7 @@ import io.airlift.units.DataSize; import io.trino.Session; import io.trino.execution.Lifespan; +import io.trino.execution.StageId; import io.trino.execution.TaskId; import io.trino.execution.TaskStateMachine; import io.trino.memory.MemoryPool; @@ -305,7 +306,7 @@ protected Map runOnce() localQueryRunner.getScheduler(), DataSize.of(256, MEGABYTE), spillSpaceTracker) - .addTaskContext(new TaskStateMachine(new TaskId("query", 0, 0), localQueryRunner.getExecutor()), + .addTaskContext(new TaskStateMachine(new TaskId(new StageId("query", 0), 0, 0), localQueryRunner.getExecutor()), session, () -> {}, false, diff --git a/testing/trino-benchmark/src/test/java/io/trino/benchmark/MemoryLocalQueryRunner.java b/testing/trino-benchmark/src/test/java/io/trino/benchmark/MemoryLocalQueryRunner.java index 39a457e5b1b6..f7758a0e674e 100644 --- a/testing/trino-benchmark/src/test/java/io/trino/benchmark/MemoryLocalQueryRunner.java +++ b/testing/trino-benchmark/src/test/java/io/trino/benchmark/MemoryLocalQueryRunner.java @@ -18,6 +18,7 @@ import io.airlift.stats.TestingGcMonitor; import io.airlift.units.DataSize; import io.trino.Session; +import io.trino.execution.StageId; import io.trino.execution.TaskId; import io.trino.execution.TaskStateMachine; import io.trino.memory.MemoryPool; @@ -81,7 +82,7 @@ public List execute(@Language("SQL") String query) spillSpaceTracker); TaskContext taskContext = queryContext - .addTaskContext(new TaskStateMachine(new TaskId("query", 0, 0), localQueryRunner.getExecutor()), + .addTaskContext(new TaskStateMachine(new TaskId(new StageId("query", 0), 0, 0), localQueryRunner.getExecutor()), localQueryRunner.getDefaultSession(), () -> {}, false, diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFailureRecovery.java 
b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFailureRecovery.java new file mode 100644 index 000000000000..7d5d24c35c61 --- /dev/null +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFailureRecovery.java @@ -0,0 +1,603 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.testing; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.graph.Traverser; +import io.airlift.units.Duration; +import io.trino.Session; +import io.trino.client.StageStats; +import io.trino.client.StatementStats; +import io.trino.execution.FailureInjector.InjectedFailureType; +import io.trino.spi.ErrorType; +import io.trino.tpch.TpchTable; +import org.assertj.core.api.AbstractThrowableAssert; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Function; +import java.util.function.Predicate; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Streams.stream; +import static io.trino.execution.FailureInjector.FAILURE_INJECTION_MESSAGE; +import static io.trino.execution.FailureInjector.InjectedFailureType.TASK_FAILURE; +import static 
io.trino.execution.FailureInjector.InjectedFailureType.TASK_GET_RESULTS_REQUEST_FAILURE; +import static io.trino.execution.FailureInjector.InjectedFailureType.TASK_GET_RESULTS_REQUEST_TIMEOUT; +import static io.trino.execution.FailureInjector.InjectedFailureType.TASK_MANAGEMENT_REQUEST_FAILURE; +import static io.trino.execution.FailureInjector.InjectedFailureType.TASK_MANAGEMENT_REQUEST_TIMEOUT; +import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder; +import static io.trino.testing.sql.TestTable.randomTableSuffix; +import static io.trino.tpch.TpchTable.CUSTOMER; +import static io.trino.tpch.TpchTable.NATION; +import static io.trino.tpch.TpchTable.ORDERS; +import static java.lang.Integer.parseInt; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; + +public abstract class AbstractTestFailureRecovery + extends AbstractTestQueryFramework +{ + protected static final int INVOCATION_COUNT = 3; + private static final Duration MAX_ERROR_DURATION = new Duration(10, SECONDS); + private static final Duration REQUEST_TIMEOUT = new Duration(10, SECONDS); + + @Override + protected final QueryRunner createQueryRunner() + throws Exception + { + return createQueryRunner( + ImmutableList.of(NATION, ORDERS, CUSTOMER), + ImmutableMap.builder() + .put("query.remote-task.max-error-duration", MAX_ERROR_DURATION.toString()) + .put("exchange.max-error-duration", MAX_ERROR_DURATION.toString()) + .put("retry-policy", "QUERY") + .put("retry-initial-delay", "0s") + .put("retry-attempts", "1") + .put("failure-injection.request-timeout", new Duration(REQUEST_TIMEOUT.toMillis() * 2, MILLISECONDS).toString()) + .put("exchange.http-client.idle-timeout", 
REQUEST_TIMEOUT.toString()) + // TODO: re-enable once failure recover supported for this functionality + .put("enable-dynamic-filtering", "false") + .put("distributed-sort", "false") + .build(), + ImmutableMap.builder() + .put("scheduler.http-client.idle-timeout", REQUEST_TIMEOUT.toString()) + .build()); + } + + protected abstract QueryRunner createQueryRunner(List> requiredTpchTables, Map configProperties, Map coordinatorProperties) + throws Exception; + + @Test(invocationCount = INVOCATION_COUNT) + public void testSimpleSelect() + { + testSelect("SELECT * FROM nation"); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testAggregation() + { + testSelect("SELECT orderStatus, count(*) FROM orders GROUP BY orderStatus"); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testJoin() + { + testSelect("SELECT * FROM orders o, customer c WHERE o.custkey = c.custkey AND c.nationKey = 1"); + } + + protected void testSelect(String query) + { + assertThatQuery(query) + .experiencing(TASK_MANAGEMENT_REQUEST_FAILURE) + .at(leafStage()) + .finishesSuccessfully(); + + assertThatQuery(query) + .experiencing(TASK_GET_RESULTS_REQUEST_FAILURE) + .at(boundaryDistributedStage()) + .finishesSuccessfully(); + + assertThatQuery(query) + .experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR)) + .at(leafStage()) + .finishesSuccessfully(); + + assertThatQuery(query) + .experiencing(TASK_FAILURE, Optional.of(ErrorType.EXTERNAL)) + .at(intermediateDistributedStage()) + .finishesSuccessfully(); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testUserFailure() + { + assertThatThrownBy(() -> getQueryRunner().execute("SELECT * FROM nation WHERE regionKey / nationKey - 1 = 0")) + .hasMessageContaining("Division by zero"); + + assertThatQuery("SELECT * FROM nation") + .experiencing(TASK_FAILURE, Optional.of(ErrorType.USER_ERROR)) + .at(leafStage()) + .failsWithErrorThat() + .hasMessageContaining(FAILURE_INJECTION_MESSAGE); + } + + 
@Test(invocationCount = INVOCATION_COUNT) + public void testCreateTable() + { + testTableModification( + Optional.empty(), + "CREATE TABLE
AS SELECT * FROM orders", + Optional.of("DROP TABLE
")); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testInsert() + { + testTableModification( + Optional.of("CREATE TABLE
AS SELECT * FROM orders WITH NO DATA"), + "INSERT INTO
SELECT * FROM orders", + Optional.of("DROP TABLE
")); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testDelete() + { + testTableModification( + Optional.of("CREATE TABLE
AS SELECT * FROM orders"), + "DELETE FROM orders WHERE orderkey = 1", + Optional.of("DROP TABLE
")); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testDeleteWithSubquery() + { + testTableModification( + Optional.of("CREATE TABLE
AS SELECT * FROM orders"), + "DELETE FROM orders WHERE custkey IN (SELECT custkey FROM customer WHERE nationkey = 1)", + Optional.of("DROP TABLE
")); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testUpdate() + { + testTableModification( + Optional.of("CREATE TABLE
AS SELECT * FROM orders"), + "UPDATE orders SET shippriority = 101 WHERE custkey = 1", + Optional.of("DROP TABLE
")); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testUpdateWithSubquery() + { + testTableModification( + Optional.of("CREATE TABLE
AS SELECT * FROM orders"), + "UPDATE orders SET shippriority = 101 WHERE custkey = (SELECT min(custkey) FROM customer)", + Optional.of("DROP TABLE
")); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testAnalyzeStatistics() + { + testTableModification( + Optional.of("CREATE TABLE
AS SELECT * FROM orders"), + "ANALYZE
", + Optional.of("DROP TABLE
")); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testRefreshMaterializedView() + { + testTableModification( + Optional.of("CREATE MATERIALIZED VIEW
AS SELECT * FROM orders"), + "REFRESH MATERIALIZED VIEW
", + Optional.of("DROP MATERIALIZED VIEW
")); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testExplainAnalyze() + { + testSelect("EXPLAIN ANALYZE SELECT orderStatus, count(*) FROM orders GROUP BY orderStatus"); + + testTableModification( + Optional.of("CREATE TABLE
AS SELECT * FROM orders WITH NO DATA"), + "EXPLAIN ANALYZE INSERT INTO
SELECT * FROM orders", + Optional.of("DROP TABLE
")); + } + + @Test(invocationCount = INVOCATION_COUNT) + public void testRequestTimeouts() + { + assertThatQuery("SELECT orderStatus, count(*) FROM orders GROUP BY orderStatus") + .experiencing(TASK_MANAGEMENT_REQUEST_TIMEOUT) + .at(intermediateDistributedStage()) + .finishesSuccessfully(); + + assertThatQuery("SELECT * FROM nation") + .experiencing(TASK_MANAGEMENT_REQUEST_TIMEOUT) + .at(boundaryDistributedStage()) + .finishesSuccessfully(); + + assertThatQuery("SELECT * FROM orders o, customer c WHERE o.custkey = c.custkey AND c.nationKey = 1") + .experiencing(TASK_GET_RESULTS_REQUEST_TIMEOUT) + .at(boundaryDistributedStage()) + .finishesSuccessfully(); + + assertThatQuery("INSERT INTO
SELECT * FROM orders") + .withSetupQuery(Optional.of("CREATE TABLE
AS SELECT * FROM orders WITH NO DATA")) + .withCleanupQuery(Optional.of("DROP TABLE
")) + .experiencing(TASK_MANAGEMENT_REQUEST_TIMEOUT) + .at(boundaryDistributedStage()) + .finishesSuccessfully(); + + assertThatQuery("INSERT INTO
SELECT * FROM orders") + .withSetupQuery(Optional.of("CREATE TABLE
AS SELECT * FROM orders WITH NO DATA")) + .withCleanupQuery(Optional.of("DROP TABLE
")) + .experiencing(TASK_GET_RESULTS_REQUEST_TIMEOUT) + .at(boundaryDistributedStage()) + .finishesSuccessfully(); + } + + protected void testTableModification(Optional setupQuery, String query, Optional cleanupQuery) + { + testTableModification(Optional.empty(), setupQuery, query, cleanupQuery); + } + + protected void testTableModification(Optional session, Optional setupQuery, String query, Optional cleanupQuery) + { + assertThatQuery(query) + .withSession(session) + .withSetupQuery(setupQuery) + .withCleanupQuery(cleanupQuery) + .experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR)) + .at(boundaryCoordinatorStage()) + .failsWithErrorThat() + .hasMessageContaining(FAILURE_INJECTION_MESSAGE); + + assertThatQuery(query) + .withSession(session) + .withSetupQuery(setupQuery) + .withCleanupQuery(cleanupQuery) + .experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR)) + .at(rootStage()) + .failsWithErrorThat() + .hasMessageContaining(FAILURE_INJECTION_MESSAGE); + + assertThatQuery(query) + .withSession(session) + .withSetupQuery(setupQuery) + .withCleanupQuery(cleanupQuery) + .experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR)) + .at(leafStage()) + .finishesSuccessfully(); + + assertThatQuery(query) + .withSession(session) + .withSetupQuery(setupQuery) + .withCleanupQuery(cleanupQuery) + .experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR)) + .at(boundaryDistributedStage()) + .finishesSuccessfully(); + + assertThatQuery(query) + .withSession(session) + .withSetupQuery(setupQuery) + .withCleanupQuery(cleanupQuery) + .experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR)) + .at(intermediateDistributedStage()) + .finishesSuccessfully(); + + assertThatQuery(query) + .withSession(session) + .withSetupQuery(setupQuery) + .withCleanupQuery(cleanupQuery) + .experiencing(TASK_MANAGEMENT_REQUEST_FAILURE) + .at(boundaryDistributedStage()) + .finishesSuccessfully(); + + assertThatQuery(query) + .withSession(session) + 
.withSetupQuery(setupQuery) + .withCleanupQuery(cleanupQuery) + .experiencing(TASK_GET_RESULTS_REQUEST_FAILURE) + .at(boundaryDistributedStage()) + .finishesSuccessfully(); + } + + private FailureRecoveryAssert assertThatQuery(String query) + { + return new FailureRecoveryAssert(query); + } + + protected class FailureRecoveryAssert + { + private final String query; + private Session session = getQueryRunner().getDefaultSession(); + private Function stageSelector; + private Optional failureType = Optional.empty(); + private Optional errorType = Optional.empty(); + private Optional setup = Optional.empty(); + private Optional cleanup = Optional.empty(); + + public FailureRecoveryAssert(String query) + { + this.query = requireNonNull(query, "query is null"); + } + + public FailureRecoveryAssert withSession(Optional session) + { + requireNonNull(session, "session is null"); + session.ifPresent(value -> this.session = value); + return this; + } + + public FailureRecoveryAssert withSetupQuery(Optional query) + { + setup = requireNonNull(query, "query is null"); + return this; + } + + public FailureRecoveryAssert withCleanupQuery(Optional query) + { + cleanup = requireNonNull(query, "query is null"); + return this; + } + + public FailureRecoveryAssert experiencing(InjectedFailureType failureType) + { + return experiencing(failureType, Optional.empty()); + } + + public FailureRecoveryAssert experiencing(InjectedFailureType failureType, Optional errorType) + { + this.failureType = Optional.of(requireNonNull(failureType, "failureType is null")); + this.errorType = requireNonNull(errorType, "errorType is null"); + if (failureType == TASK_FAILURE) { + checkArgument(errorType.isPresent(), "error type must be present when injection type is task failure"); + } + else { + checkArgument(errorType.isEmpty(), "error type must not be present when injection type is not task failure"); + } + return this; + } + + public FailureRecoveryAssert at(Function stageSelector) + { + 
this.stageSelector = requireNonNull(stageSelector, "stageSelector is null"); + return this; + } + + private ExecutionResult executeExpected() + { + return execute(query, Optional.empty()); + } + + private ExecutionResult executeActual(MaterializedResult expected) + { + requireNonNull(stageSelector, "stageSelector must be set"); + int stageId = stageSelector.apply(expected); + String token = UUID.randomUUID().toString(); + failureType.ifPresent(failure -> getQueryRunner().injectTaskFailure( + token, + stageId, + 0, + 0, + failure, + errorType)); + + ExecutionResult actual = execute(query, Optional.of(token)); + assertEquals(getStageStats(actual.getQueryResult(), stageId).getFailedTasks(), failureType.isPresent() ? 1 : 0); + return actual; + } + + private ExecutionResult execute(String query, Optional traceToken) + { + String tableName = "table_" + randomTableSuffix(); + setup.ifPresent(sql -> getQueryRunner().execute(session, resolveTableName(sql, tableName))); + + MaterializedResult result = null; + RuntimeException failure = null; + try { + Session sessionWithToken = Session.builder(session) + .setTraceToken(traceToken) + .build(); + result = getQueryRunner().execute(sessionWithToken, resolveTableName(query, tableName)); + } + catch (RuntimeException e) { + failure = e; + } + + Optional updatedTableContent = Optional.empty(); + if (result != null && result.getUpdateCount().isPresent()) { + updatedTableContent = Optional.of(getQueryRunner().execute(session, "SELECT * FROM " + tableName)); + } + + Optional updatedTableStatistics = Optional.empty(); + if (result != null && result.getUpdateType().isPresent() && result.getUpdateType().get().equals("ANALYZE")) { + updatedTableStatistics = Optional.of(getQueryRunner().execute(session, "SHOW STATS FOR " + tableName)); + } + + try { + cleanup.ifPresent(sql -> getQueryRunner().execute(session, resolveTableName(sql, tableName))); + } + catch (RuntimeException e) { + if (failure == null) { + failure = e; + } + else if 
(failure != e) { + failure.addSuppressed(e); + } + } + + if (failure != null) { + throw failure; + } + + return new ExecutionResult(result, updatedTableContent, updatedTableStatistics); + } + + public void finishesSuccessfully() + { + ExecutionResult expected = executeExpected(); + MaterializedResult expectedQueryResult = expected.getQueryResult(); + ExecutionResult actual = executeActual(expectedQueryResult); + MaterializedResult actualQueryResult = actual.getQueryResult(); + + boolean isAnalyze = expectedQueryResult.getUpdateType().isPresent() && expectedQueryResult.getUpdateType().get().equals("ANALYZE"); + boolean isUpdate = expectedQueryResult.getUpdateCount().isPresent(); + boolean isExplain = query.trim().toUpperCase(ENGLISH).startsWith("EXPLAIN"); + if (isAnalyze) { + assertEquals(actualQueryResult.getUpdateCount(), expectedQueryResult.getUpdateCount()); + assertThat(expected.getUpdatedTableStatistics()).isPresent(); + assertThat(actual.getUpdatedTableStatistics()).isPresent(); + + MaterializedResult expectedUpdatedTableStatistics = expected.getUpdatedTableStatistics().get(); + MaterializedResult actualUpdatedTableStatistics = actual.getUpdatedTableStatistics().get(); + assertEqualsIgnoreOrder(actualUpdatedTableStatistics, expectedUpdatedTableStatistics, "For query: \n " + query); + } + else if (isUpdate) { + assertEquals(actualQueryResult.getUpdateCount(), expectedQueryResult.getUpdateCount()); + assertThat(expected.getUpdatedTableContent()).isPresent(); + assertThat(actual.getUpdatedTableContent()).isPresent(); + MaterializedResult expectedUpdatedTableContent = expected.getUpdatedTableContent().get(); + MaterializedResult actualUpdatedTableContent = actual.getUpdatedTableContent().get(); + assertEqualsIgnoreOrder(actualUpdatedTableContent, expectedUpdatedTableContent, "For query: \n " + query); + } + else if (isExplain) { + assertEquals(actualQueryResult.getRowCount(), expectedQueryResult.getRowCount()); + } + else { + 
assertEqualsIgnoreOrder(actualQueryResult, expectedQueryResult, "For query: \n " + query); + } + } + + public AbstractThrowableAssert failsWithErrorThat() + { + ExecutionResult expected = executeExpected(); + return assertThatThrownBy(() -> executeActual(expected.getQueryResult())); + } + + private String resolveTableName(String query, String tableName) + { + return query.replaceAll("
", tableName); + } + } + + private static class ExecutionResult + { + private final MaterializedResult queryResult; + private final Optional updatedTableContent; + private final Optional updatedTableStatistics; + + private ExecutionResult( + MaterializedResult queryResult, + Optional updatedTableContent, + Optional updatedTableStatistics) + { + this.queryResult = requireNonNull(queryResult, "queryResult is null"); + this.updatedTableContent = requireNonNull(updatedTableContent, "updatedTableContent is null"); + this.updatedTableStatistics = requireNonNull(updatedTableStatistics, "updatedTableStatistics is null"); + } + + public MaterializedResult getQueryResult() + { + return queryResult; + } + + public Optional getUpdatedTableContent() + { + return updatedTableContent; + } + + public Optional getUpdatedTableStatistics() + { + return updatedTableStatistics; + } + } + + protected static Function rootStage() + { + return (result) -> parseInt(getRootStage(result).getStageId()); + } + + protected static Function boundaryCoordinatorStage() + { + return (result) -> findStageId(result, stage -> stage.isCoordinatorOnly() && stage.getSubStages().stream().noneMatch(StageStats::isCoordinatorOnly)); + } + + protected static Function boundaryDistributedStage() + { + return (result) -> { + StageStats rootStage = getRootStage(result); + if (!rootStage.isCoordinatorOnly()) { + return parseInt(rootStage.getStageId()); + } + + StageStats boundaryCoordinatorStage = findStage(result, stage -> stage.isCoordinatorOnly() && stage.getSubStages().stream().noneMatch(StageStats::isCoordinatorOnly)); + StageStats boundaryDistributedStage = boundaryCoordinatorStage.getSubStages().get(ThreadLocalRandom.current().nextInt(boundaryCoordinatorStage.getSubStages().size())); + return parseInt(boundaryDistributedStage.getStageId()); + }; + } + + protected static Function intermediateDistributedStage() + { + return (result) -> findStageId(result, stage -> !stage.isCoordinatorOnly() && 
!stage.getSubStages().isEmpty()); + } + + protected static Function leafStage() + { + return (result) -> findStageId(result, stage -> stage.getSubStages().isEmpty()); + } + + private static int findStageId(MaterializedResult result, Predicate predicate) + { + return parseInt(findStage(result, predicate).getStageId()); + } + + private static StageStats findStage(MaterializedResult result, Predicate predicate) + { + List stages = stream(Traverser.forTree(StageStats::getSubStages).breadthFirst(getRootStage(result))) + .filter(predicate) + .collect(toImmutableList()); + if (stages.isEmpty()) { + throw new IllegalArgumentException("stage not found"); + } + return stages.get(ThreadLocalRandom.current().nextInt(stages.size())); + } + + private static StageStats getStageStats(MaterializedResult result, int stageId) + { + return stream(Traverser.forTree(StageStats::getSubStages).breadthFirst(getRootStage(result))) + .filter(stageStats -> parseInt(stageStats.getStageId()) == stageId) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("stage stats not found: " + stageId)); + } + + private static StageStats getRootStage(MaterializedResult result) + { + StatementStats statementStats = result.getStatementStats().orElseThrow(() -> new IllegalArgumentException("statement stats is not present")); + return requireNonNull(statementStats.getRootStage(), "root stage is null"); + } +} diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java index e00bfcbfcd7b..092014bb05b2 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java @@ -111,6 +111,7 @@ public ResultWithQueryId execute(Session session, @Language("SQL") String sql } resultsSession.setWarnings(results.getWarnings()); + 
resultsSession.setStatementStats(results.getStats()); T result = resultsSession.build(client.getSetSessionProperties(), client.getResetSessionProperties()); return new ResultWithQueryId<>(new QueryId(results.getId()), result); diff --git a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java index be96b3b76dfc..5949b2458563 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java @@ -25,6 +25,7 @@ import io.trino.Session.SessionBuilder; import io.trino.connector.CatalogName; import io.trino.cost.StatsCalculator; +import io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.execution.QueryManager; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.AllNodes; @@ -35,6 +36,7 @@ import io.trino.server.BasicQueryInfo; import io.trino.server.SessionPropertyDefaults; import io.trino.server.testing.TestingTrinoServer; +import io.trino.spi.ErrorType; import io.trino.spi.Plugin; import io.trino.spi.QueryId; import io.trino.spi.eventlistener.EventListener; @@ -535,6 +537,26 @@ public Lock getExclusiveLock() return lock.writeLock(); } + @Override + public void injectTaskFailure( + String traceToken, + int stageId, + int partitionId, + int attemptId, + InjectedFailureType injectionType, + Optional errorType) + { + for (TestingTrinoServer server : servers) { + server.injectTaskFailure( + traceToken, + stageId, + partitionId, + attemptId, + injectionType, + errorType); + } + } + @Override public final void close() { diff --git a/testing/trino-testing/src/main/java/io/trino/testing/ResultsSession.java b/testing/trino-testing/src/main/java/io/trino/testing/ResultsSession.java index c191763816d4..7b98dd3ef2ee 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/ResultsSession.java +++ 
b/testing/trino-testing/src/main/java/io/trino/testing/ResultsSession.java @@ -15,6 +15,7 @@ import io.trino.client.QueryData; import io.trino.client.QueryStatusInfo; +import io.trino.client.StatementStats; import io.trino.client.Warning; import java.util.List; @@ -35,6 +36,8 @@ default void setUpdateCount(long count) default void setWarnings(List warnings) {} + default void setStatementStats(StatementStats statementStats) {} + void addResults(QueryStatusInfo statusInfo, QueryData data); T build(Map setSessionProperties, Set resetSessionProperties); diff --git a/testing/trino-testing/src/main/java/io/trino/testing/StandaloneQueryRunner.java b/testing/trino-testing/src/main/java/io/trino/testing/StandaloneQueryRunner.java index ee2016f03f23..95f1573830de 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/StandaloneQueryRunner.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/StandaloneQueryRunner.java @@ -17,12 +17,14 @@ import io.trino.Session; import io.trino.connector.CatalogName; import io.trino.cost.StatsCalculator; +import io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.metadata.AllNodes; import io.trino.metadata.InternalNode; import io.trino.metadata.Metadata; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.SqlFunction; import io.trino.server.testing.TestingTrinoServer; +import io.trino.spi.ErrorType; import io.trino.spi.Plugin; import io.trino.split.PageSourceManager; import io.trino.split.SplitManager; @@ -34,6 +36,7 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; @@ -259,6 +262,24 @@ public Lock getExclusiveLock() return lock.writeLock(); } + @Override + public void injectTaskFailure( + String traceToken, + int stageId, + int partitionId, + int attemptId, + InjectedFailureType injectionType, + Optional errorType) + { 
+ server.injectTaskFailure( + traceToken, + stageId, + partitionId, + attemptId, + injectionType, + errorType); + } + private static TestingTrinoServer createTestingTrinoServer() { return TestingTrinoServer.builder() diff --git a/testing/trino-testing/src/main/java/io/trino/testing/TestingTrinoClient.java b/testing/trino-testing/src/main/java/io/trino/testing/TestingTrinoClient.java index b8a47b702705..833b3380bc5b 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/TestingTrinoClient.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/TestingTrinoClient.java @@ -21,6 +21,7 @@ import io.trino.client.QueryStatusInfo; import io.trino.client.Row; import io.trino.client.RowField; +import io.trino.client.StatementStats; import io.trino.client.Warning; import io.trino.server.testing.TestingTrinoServer; import io.trino.spi.type.ArrayType; @@ -126,6 +127,7 @@ private class MaterializedResultSession private final AtomicReference> updateType = new AtomicReference<>(Optional.empty()); private final AtomicReference updateCount = new AtomicReference<>(OptionalLong.empty()); private final AtomicReference> warnings = new AtomicReference<>(ImmutableList.of()); + private final AtomicReference> statementStats = new AtomicReference<>(Optional.empty()); @Override public void setUpdateType(String type) @@ -145,6 +147,12 @@ public void setWarnings(List warnings) this.warnings.set(warnings); } + @Override + public void setStatementStats(StatementStats statementStats) + { + this.statementStats.set(Optional.of(statementStats)); + } + @Override public void addResults(QueryStatusInfo statusInfo, QueryData data) { @@ -169,7 +177,8 @@ public MaterializedResult build(Map setSessionProperties, Set assertEquals(queryRunner.getCoordinator().getFullQueryInfo(queryId).getOutputStage().get().getState(), FLUSHING)); + () -> assertEquals(queryRunner.getCoordinator().getFullQueryInfo(queryId).getOutputStage().get().getState(), StageState.RUNNING)); - // wait for the sub stages 
to go to cancelled state + // wait for the sub stages to go to pending state assertEventually( new Duration(10, SECONDS), - () -> assertEquals(queryRunner.getCoordinator().getFullQueryInfo(queryId).getOutputStage().get().getSubStages().get(0).getState(), CANCELED)); + () -> assertEquals(queryRunner.getCoordinator().getFullQueryInfo(queryId).getOutputStage().get().getSubStages().get(0).getState(), StageState.PENDING)); QueryInfo queryInfo = queryRunner.getCoordinator().getFullQueryInfo(queryId); assertEquals(queryInfo.getState(), RUNNING); - assertEquals(queryInfo.getOutputStage().get().getState(), FLUSHING); + assertEquals(queryInfo.getOutputStage().get().getState(), StageState.RUNNING); assertEquals(queryInfo.getOutputStage().get().getSubStages().size(), 1); - assertEquals(queryInfo.getOutputStage().get().getSubStages().get(0).getState(), CANCELED); + assertEquals(queryInfo.getOutputStage().get().getSubStages().get(0).getState(), StageState.PENDING); } @AfterClass(alwaysRun = true)