diff --git a/core/trino-main/src/main/java/io/trino/FeaturesConfig.java b/core/trino-main/src/main/java/io/trino/FeaturesConfig.java index 00901d3a90fe..7650f2128309 100644 --- a/core/trino-main/src/main/java/io/trino/FeaturesConfig.java +++ b/core/trino-main/src/main/java/io/trino/FeaturesConfig.java @@ -36,6 +36,7 @@ import java.util.List; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.units.DataSize.Unit.GIGABYTE; import static io.airlift.units.DataSize.Unit.KILOBYTE; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.trino.sql.analyzer.RegexLibrary.JONI; @@ -152,6 +153,9 @@ public class FeaturesConfig private Duration retryInitialDelay = new Duration(10, SECONDS); private Duration retryMaxDelay = new Duration(1, MINUTES); + private DataSize faultTolerantExecutionTargetTaskInputSize = DataSize.of(1, GIGABYTE); + private int faultTolerantExecutionTargetTaskSplitCount = 16; + public enum JoinReorderingStrategy { NONE, @@ -1169,4 +1173,30 @@ public FeaturesConfig setRetryMaxDelay(Duration retryMaxDelay) this.retryMaxDelay = retryMaxDelay; return this; } + + @NotNull + public DataSize getFaultTolerantExecutionTargetTaskInputSize() + { + return faultTolerantExecutionTargetTaskInputSize; + } + + @Config("fault-tolerant-execution-target-task-input-size") + public FeaturesConfig setFaultTolerantExecutionTargetTaskInputSize(DataSize faultTolerantExecutionTargetTaskInputSize) + { + this.faultTolerantExecutionTargetTaskInputSize = faultTolerantExecutionTargetTaskInputSize; + return this; + } + + @Min(1) + public int getFaultTolerantExecutionTargetTaskSplitCount() + { + return faultTolerantExecutionTargetTaskSplitCount; + } + + @Config("fault-tolerant-execution-target-task-split-count") + public FeaturesConfig setFaultTolerantExecutionTargetTaskSplitCount(int faultTolerantExecutionTargetTaskSplitCount) + { + this.faultTolerantExecutionTargetTaskSplitCount = faultTolerantExecutionTargetTaskSplitCount; + 
return this; + } } diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index fc2976d4b2fb..cacc06ce3780 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -149,6 +149,8 @@ public final class SystemSessionProperties public static final String RETRY_ATTEMPTS = "retry_attempts"; public static final String RETRY_INITIAL_DELAY = "retry_initial_delay"; public static final String RETRY_MAX_DELAY = "retry_max_delay"; + public static final String FAULT_TOLERANT_EXECUTION_TARGET_TASK_INPUT_SIZE = "fault_tolerant_execution_target_task_input_size"; + public static final String FAULT_TOLERANT_EXECUTION_TARGET_TASK_SPLIT_COUNT = "fault_tolerant_execution_target_task_split_count"; private final List> sessionProperties; @@ -692,6 +694,16 @@ public SystemSessionProperties( RETRY_MAX_DELAY, "Maximum delay before initiating a retry attempt. 
Delay increases exponentially for each subsequent attempt starting from 'retry_initial_delay'", featuresConfig.getRetryMaxDelay(), + false), + dataSizeProperty( + FAULT_TOLERANT_EXECUTION_TARGET_TASK_INPUT_SIZE, + "Target size of all task inputs for a single fault tolerant task", + featuresConfig.getFaultTolerantExecutionTargetTaskInputSize(), + false), + integerProperty( + FAULT_TOLERANT_EXECUTION_TARGET_TASK_SPLIT_COUNT, + "Target number of splits for a single fault tolerant task", + featuresConfig.getFaultTolerantExecutionTargetTaskSplitCount(), false)); } @@ -1222,6 +1234,11 @@ public static RetryPolicy getRetryPolicy(Session session) throw new TrinoException(NOT_SUPPORTED, "Distributed sort is not supported with automatic retries enabled"); } } + if (retryPolicy == RetryPolicy.TASK) { + if (isGroupedExecutionEnabled(session) || isDynamicScheduleForGroupedExecution(session)) { + throw new TrinoException(NOT_SUPPORTED, "Grouped execution is not supported with task level retries enabled"); + } + } return retryPolicy; } @@ -1239,4 +1256,14 @@ public static Duration getRetryMaxDelay(Session session) { return session.getSystemProperty(RETRY_MAX_DELAY, Duration.class); } + + public static DataSize getFaultTolerantExecutionTargetTaskInputSize(Session session) + { + return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_TARGET_TASK_INPUT_SIZE, DataSize.class); + } + + public static int getFaultTolerantExecutionTargetTaskSplitCount(Session session) + { + return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_TARGET_TASK_SPLIT_COUNT, Integer.class); + } } diff --git a/core/trino-main/src/main/java/io/trino/exchange/ExchangeManagerModule.java b/core/trino-main/src/main/java/io/trino/exchange/ExchangeManagerModule.java new file mode 100644 index 000000000000..e05fe95d2de9 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/exchange/ExchangeManagerModule.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.exchange; + +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +public class ExchangeManagerModule + implements Module +{ + @Override + public void configure(Binder binder) + { + binder.bind(ExchangeManagerRegistry.class).in(Scopes.SINGLETON); + } +} diff --git a/core/trino-main/src/main/java/io/trino/exchange/ExchangeManagerRegistry.java b/core/trino-main/src/main/java/io/trino/exchange/ExchangeManagerRegistry.java new file mode 100644 index 000000000000..923ddbbc9897 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/exchange/ExchangeManagerRegistry.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.exchange; + +import io.airlift.log.Logger; +import io.trino.metadata.HandleResolver; +import io.trino.spi.TrinoException; +import io.trino.spi.classloader.ThreadContextClassLoader; +import io.trino.spi.exchange.ExchangeManager; +import io.trino.spi.exchange.ExchangeManagerFactory; + +import javax.annotation.concurrent.GuardedBy; +import javax.inject.Inject; + +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; +import static io.airlift.configuration.ConfigurationLoader.loadPropertiesFrom; +import static io.trino.spi.StandardErrorCode.EXCHANGE_MANAGER_NOT_CONFIGURED; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class ExchangeManagerRegistry +{ + private static final Logger log = Logger.get(ExchangeManagerRegistry.class); + + private static final File CONFIG_FILE = new File("etc/exchange-manager.properties"); + private static final String EXCHANGE_MANAGER_NAME_PROPERTY = "exchange-manager.name"; + + private final HandleResolver handleResolver; + + private final Map exchangeManagerFactories = new ConcurrentHashMap<>(); + + @GuardedBy("this") + private volatile ExchangeManager exchangeManager; + + @Inject + public ExchangeManagerRegistry(HandleResolver handleResolver) + { + this.handleResolver = requireNonNull(handleResolver, "handleResolver is null"); + } + + public void addExchangeManagerFactory(ExchangeManagerFactory factory) + { + requireNonNull(factory, "factory is null"); + if (exchangeManagerFactories.putIfAbsent(factory.getName(), factory) != null) { + throw new IllegalArgumentException(format("Exchange manager factory '%s' is already registered", 
factory.getName())); + } + } + + public void loadExchangeManager() + { + if (!CONFIG_FILE.exists()) { + log.info("Exchange manager configuration file is not present: %s", CONFIG_FILE.getAbsoluteFile()); + return; + } + + Map properties = loadProperties(CONFIG_FILE); + String name = properties.remove(EXCHANGE_MANAGER_NAME_PROPERTY); + checkArgument(!isNullOrEmpty(name), "Exchange manager configuration %s does not contain %s", CONFIG_FILE, EXCHANGE_MANAGER_NAME_PROPERTY); + + loadExchangeManager(name, properties); + } + + public synchronized void loadExchangeManager(String name, Map properties) + { + log.info("-- Loading exchange manager %s --", name); + + checkState(exchangeManager == null, "exchangeManager is already loaded"); + + ExchangeManagerFactory factory = exchangeManagerFactories.get(name); + checkArgument(factory != null, "Exchange manager factory '%s' is not registered. Available factories: %s", name, exchangeManagerFactories.keySet()); + + ExchangeManager exchangeManager; + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(factory.getClass().getClassLoader())) { + exchangeManager = factory.create(properties); + } + handleResolver.setExchangeManagerHandleResolver(factory.getHandleResolver()); + + log.info("-- Loaded exchange manager %s --", name); + + this.exchangeManager = exchangeManager; + } + + public ExchangeManager getExchangeManager() + { + ExchangeManager exchangeManager = this.exchangeManager; + if (exchangeManager == null) { + throw new TrinoException(EXCHANGE_MANAGER_NOT_CONFIGURED, "Exchange manager is not configured"); + } + return exchangeManager; + } + + private static Map loadProperties(File configFile) + { + try { + return new HashMap<>(loadPropertiesFrom(configFile.getPath())); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to read configuration file: " + configFile, e); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/ExecutionFailureInfo.java 
b/core/trino-main/src/main/java/io/trino/execution/ExecutionFailureInfo.java index ef9d29b19199..be5e1934cc7b 100644 --- a/core/trino-main/src/main/java/io/trino/execution/ExecutionFailureInfo.java +++ b/core/trino-main/src/main/java/io/trino/execution/ExecutionFailureInfo.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableList; import io.trino.client.ErrorLocation; import io.trino.client.FailureInfo; +import io.trino.failuredetector.FailureDetector; import io.trino.spi.ErrorCode; import io.trino.spi.HostAddress; @@ -29,6 +30,8 @@ import java.util.regex.Pattern; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.failuredetector.FailureDetector.State.GONE; +import static io.trino.spi.StandardErrorCode.REMOTE_HOST_GONE; import static java.util.Objects.requireNonNull; @Immutable @@ -175,4 +178,21 @@ else if (matcher.group(4) != null) { } return new StackTraceElement("Unknown", stack, null, -1); } + + public static ExecutionFailureInfo rewriteTransportFailure(FailureDetector failureDetector, ExecutionFailureInfo executionFailureInfo) + { + if (executionFailureInfo.getRemoteHost() == null || failureDetector.getState(executionFailureInfo.getRemoteHost()) != GONE) { + return executionFailureInfo; + } + + return new ExecutionFailureInfo( + executionFailureInfo.getType(), + executionFailureInfo.getMessage(), + executionFailureInfo.getCause(), + executionFailureInfo.getSuppressed(), + executionFailureInfo.getStack(), + executionFailureInfo.getErrorLocation(), + REMOTE_HOST_GONE.toErrorCode(), + executionFailureInfo.getRemoteHost()); + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java index 90e09fad9038..525b6ea67e0a 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java @@ -1184,7 +1184,7 @@ private 
static QueryStats pruneQueryStats(QueryStats queryStats) queryStats.getPhysicalWrittenDataSize(), queryStats.getStageGcStatistics(), queryStats.getDynamicFiltersStats(), - ImmutableList.of()); // Remove the operator summaries as OperatorInfo (especially ExchangeClientStatus) can hold onto a large amount of memory + ImmutableList.of()); // Remove the operator summaries as OperatorInfo (especially DirectExchangeClientStatus) can hold onto a large amount of memory } public static class QueryOutputManager diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskSource.java b/core/trino-main/src/main/java/io/trino/execution/SplitAssignment.java similarity index 68% rename from core/trino-main/src/main/java/io/trino/execution/TaskSource.java rename to core/trino-main/src/main/java/io/trino/execution/SplitAssignment.java index 074b9c6be67a..2f77bcb821f0 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskSource.java +++ b/core/trino-main/src/main/java/io/trino/execution/SplitAssignment.java @@ -24,7 +24,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; -public class TaskSource +public class SplitAssignment { private final PlanNodeId planNodeId; private final Set splits; @@ -32,7 +32,7 @@ public class TaskSource private final boolean noMoreSplits; @JsonCreator - public TaskSource( + public SplitAssignment( @JsonProperty("planNodeId") PlanNodeId planNodeId, @JsonProperty("splits") Set splits, @JsonProperty("noMoreSplitsForLifespan") Set noMoreSplitsForLifespan, @@ -44,7 +44,7 @@ public TaskSource( this.noMoreSplits = noMoreSplits; } - public TaskSource(PlanNodeId planNodeId, Set splits, boolean noMoreSplits) + public SplitAssignment(PlanNodeId planNodeId, Set splits, boolean noMoreSplits) { this(planNodeId, splits, ImmutableSet.of(), noMoreSplits); } @@ -73,43 +73,43 @@ public boolean isNoMoreSplits() return noMoreSplits; } - public TaskSource update(TaskSource source) + public 
SplitAssignment update(SplitAssignment assignment) { - checkArgument(planNodeId.equals(source.getPlanNodeId()), "Expected source %s, but got source %s", planNodeId, source.getPlanNodeId()); + checkArgument(planNodeId.equals(assignment.getPlanNodeId()), "Expected assignment for node %s, but got assignment for node %s", planNodeId, assignment.getPlanNodeId()); - if (isNewer(source)) { - // assure the new source is properly formed - // we know that either the new source one has new splits and/or it is marking the source as closed - checkArgument(!noMoreSplits || splits.containsAll(source.getSplits()), "Source %s has new splits, but no more splits already set", planNodeId); + if (isNewer(assignment)) { + // assure the new assignment is properly formed + // we know that either the new assignment one has new splits and/or it is marking the assignment as closed + checkArgument(!noMoreSplits || splits.containsAll(assignment.getSplits()), "Assignment %s has new splits, but no more splits already set", planNodeId); Set newSplits = ImmutableSet.builder() .addAll(splits) - .addAll(source.getSplits()) + .addAll(assignment.getSplits()) .build(); Set newNoMoreSplitsForDriverGroup = ImmutableSet.builder() .addAll(noMoreSplitsForLifespan) - .addAll(source.getNoMoreSplitsForLifespan()) + .addAll(assignment.getNoMoreSplitsForLifespan()) .build(); - return new TaskSource( + return new SplitAssignment( planNodeId, newSplits, newNoMoreSplitsForDriverGroup, - source.isNoMoreSplits()); + assignment.isNoMoreSplits()); } else { - // the specified source is older than this one + // the specified assignment is older than this one return this; } } - private boolean isNewer(TaskSource source) + private boolean isNewer(SplitAssignment assignment) { - // the specified source is newer if it changes the no more + // the specified assignment is newer if it changes the no more // splits flag or if it contains new splits - return (!noMoreSplits && source.isNoMoreSplits()) || - 
(!noMoreSplitsForLifespan.containsAll(source.getNoMoreSplitsForLifespan())) || - (!splits.containsAll(source.getSplits())); + return (!noMoreSplits && assignment.isNoMoreSplits()) || + (!noMoreSplitsForLifespan.containsAll(assignment.getNoMoreSplitsForLifespan())) || + (!splits.containsAll(assignment.getSplits())); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index bfae1a6636f5..5066d8c56092 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -22,12 +22,14 @@ import io.trino.connector.CatalogName; import io.trino.cost.CostCalculator; import io.trino.cost.StatsCalculator; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.QueryPreparer.PreparedQuery; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.scheduler.ExecutionPolicy; import io.trino.execution.scheduler.NodeScheduler; import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.execution.scheduler.SqlQueryScheduler; +import io.trino.execution.scheduler.TaskSourceFactory; import io.trino.execution.warnings.WarningCollector; import io.trino.failuredetector.FailureDetector; import io.trino.memory.VersionedMemoryPoolId; @@ -119,6 +121,8 @@ public class SqlQueryExecution private final TableExecuteContextManager tableExecuteContextManager; private final TypeAnalyzer typeAnalyzer; private final TaskManager coordinatorTaskManager; + private final ExchangeManagerRegistry exchangeManagerRegistry; + private final TaskSourceFactory taskSourceFactory; private SqlQueryExecution( PreparedQuery preparedQuery, @@ -146,7 +150,9 @@ private SqlQueryExecution( WarningCollector warningCollector, TableExecuteContextManager tableExecuteContextManager, TypeAnalyzer typeAnalyzer, - TaskManager coordinatorTaskManager) + 
TaskManager coordinatorTaskManager, + ExchangeManagerRegistry exchangeManagerRegistry, + TaskSourceFactory taskSourceFactory) { try (SetThreadName ignored = new SetThreadName("Query-%s", stateMachine.getQueryId())) { this.slug = requireNonNull(slug, "slug is null"); @@ -203,6 +209,8 @@ private SqlQueryExecution( this.remoteTaskFactory = new MemoryTrackingRemoteTaskFactory(requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"), stateMachine); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.coordinatorTaskManager = requireNonNull(coordinatorTaskManager, "coordinatorTaskManager is null"); + this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); + this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); } } @@ -516,7 +524,9 @@ private void planDistribution(PlanRoot plan) tableExecuteContextManager, metadata, splitSourceFactory, - coordinatorTaskManager); + coordinatorTaskManager, + exchangeManagerRegistry, + taskSourceFactory); queryScheduler.set(scheduler); @@ -703,6 +713,8 @@ public static class SqlQueryExecutionFactory private final TableExecuteContextManager tableExecuteContextManager; private final TypeAnalyzer typeAnalyzer; private final TaskManager coordinatorTaskManager; + private final ExchangeManagerRegistry exchangeManagerRegistry; + private final TaskSourceFactory taskSourceFactory; @Inject SqlQueryExecutionFactory( @@ -727,7 +739,9 @@ public static class SqlQueryExecutionFactory DynamicFilterService dynamicFilterService, TableExecuteContextManager tableExecuteContextManager, TypeAnalyzer typeAnalyzer, - TaskManager coordinatorTaskManager) + TaskManager coordinatorTaskManager, + ExchangeManagerRegistry exchangeManagerRegistry, + TaskSourceFactory taskSourceFactory) { requireNonNull(config, "config is null"); this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); @@ -752,6 +766,8 @@ public static class 
SqlQueryExecutionFactory this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.coordinatorTaskManager = requireNonNull(coordinatorTaskManager, "coordinatorTaskManager is null"); + this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); + this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); } @Override @@ -791,7 +807,9 @@ public QueryExecution createQueryExecution( warningCollector, tableExecuteContextManager, typeAnalyzer, - coordinatorTaskManager); + coordinatorTaskManager, + exchangeManagerRegistry, + taskSourceFactory); } } } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java index afdf36f9d3e4..28f49afe1272 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java @@ -23,6 +23,7 @@ import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.trino.Session; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.DynamicFiltersCollector.VersionedDynamicFilterDomains; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.buffer.BufferResult; @@ -103,9 +104,10 @@ public static SqlTask createSqlTask( Consumer onDone, DataSize maxBufferSize, DataSize maxBroadcastBufferSize, + ExchangeManagerRegistry exchangeManagerRegistry, CounterStat failedTasks) { - SqlTask sqlTask = new SqlTask(taskId, location, nodeId, queryContext, sqlTaskExecutionFactory, taskNotificationExecutor, maxBufferSize, maxBroadcastBufferSize); + SqlTask sqlTask = new SqlTask(taskId, location, nodeId, queryContext, sqlTaskExecutionFactory, taskNotificationExecutor, maxBufferSize, maxBroadcastBufferSize, exchangeManagerRegistry); 
sqlTask.initialize(onDone, failedTasks); return sqlTask; } @@ -118,7 +120,8 @@ private SqlTask( SqlTaskExecutionFactory sqlTaskExecutionFactory, ExecutorService taskNotificationExecutor, DataSize maxBufferSize, - DataSize maxBroadcastBufferSize) + DataSize maxBroadcastBufferSize, + ExchangeManagerRegistry exchangeManagerRegistry) { this.taskId = requireNonNull(taskId, "taskId is null"); this.taskInstanceId = UUID.randomUUID().toString(); @@ -138,7 +141,8 @@ private SqlTask( // Pass a memory context supplier instead of a memory context to the output buffer, // because we haven't created the task context that holds the memory context yet. () -> queryContext.getTaskContextByTaskId(taskId).localSystemMemoryContext(), - () -> notifyStatusChanged()); + () -> notifyStatusChanged(), + exchangeManagerRegistry); taskStateMachine = new TaskStateMachine(taskId, taskNotificationExecutor); } @@ -421,7 +425,7 @@ public synchronized ListenableFuture getTaskInfo(long callersCurrentVe public TaskInfo updateTask( Session session, Optional fragment, - List sources, + List splitAssignments, OutputBuffers outputBuffers, Map dynamicFilterDomains) { @@ -458,7 +462,7 @@ public TaskInfo updateTask( } if (taskExecution != null) { - taskExecution.addSources(sources); + taskExecution.addSplitAssignments(splitAssignments); taskExecution.getTaskContext().addDynamicFilter(dynamicFilterDomains); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java index 15529df5aa74..c6e0877e8621 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java @@ -130,7 +130,7 @@ public class SqlTaskExecution // guarded for update only @GuardedBy("this") - private final ConcurrentMap unpartitionedSources = new ConcurrentHashMap<>(); + private final ConcurrentMap unpartitionedSplitAssignments = new 
ConcurrentHashMap<>(); @GuardedBy("this") private long maxAcknowledgedSplit = Long.MIN_VALUE; @@ -283,17 +283,17 @@ public TaskContext getTaskContext() return taskContext; } - public void addSources(List sources) + public void addSplitAssignments(List splitAssignments) { - requireNonNull(sources, "sources is null"); - checkState(!Thread.holdsLock(this), "Cannot add sources while holding a lock on the %s", getClass().getSimpleName()); + requireNonNull(splitAssignments, "splitAssignments is null"); + checkState(!Thread.holdsLock(this), "Cannot add split assignments while holding a lock on the %s", getClass().getSimpleName()); try (SetThreadName ignored = new SetThreadName("Task-%s", taskId)) { - // update our record of sources and schedule drivers for new partitioned splits - Map updatedUnpartitionedSources = updateSources(sources); + // update our record of split assignments and schedule drivers for new partitioned splits + Map updatedUnpartitionedSources = updateSplitAssignments(splitAssignments); // tell existing drivers about the new splits; it is safe to update drivers - // multiple times and out of order because sources contain full record of + // multiple times and out of order because split assignments contain full record of // the unpartitioned splits for (WeakReference driverReference : drivers) { Driver driver = driverReference.get(); @@ -308,11 +308,11 @@ public void addSources(List sources) if (sourceId.isEmpty()) { continue; } - TaskSource sourceUpdate = updatedUnpartitionedSources.get(sourceId.get()); - if (sourceUpdate == null) { + SplitAssignment splitAssignmentUpdate = updatedUnpartitionedSources.get(sourceId.get()); + if (splitAssignmentUpdate == null) { continue; } - driver.updateSource(sourceUpdate); + driver.updateSplitAssignment(splitAssignmentUpdate); } // we may have transitioned to no more splits, so check for completion @@ -320,14 +320,14 @@ public void addSources(List sources) } } - private synchronized Map updateSources(List sources) + 
private synchronized Map updateSplitAssignments(List splitAssignments) { - Map updatedUnpartitionedSources = new HashMap<>(); + Map updatedUnpartitionedSplitAssignments = new HashMap<>(); // first remove any split that was already acknowledged long currentMaxAcknowledgedSplit = this.maxAcknowledgedSplit; - sources = sources.stream() - .map(source -> new TaskSource( + splitAssignments = splitAssignments.stream() + .map(source -> new SplitAssignment( source.getPlanNodeId(), source.getSplits().stream() .filter(scheduledSplit -> scheduledSplit.getSequenceId() > currentMaxAcknowledgedSplit) @@ -338,13 +338,13 @@ private synchronized Map updateSources(List source.isNoMoreSplits())) .collect(toList()); - // update task with new sources - for (TaskSource source : sources) { - if (driverRunnerFactoriesWithSplitLifeCycle.containsKey(source.getPlanNodeId())) { - schedulePartitionedSource(source); + // update task with new assignments + for (SplitAssignment assignment : splitAssignments) { + if (driverRunnerFactoriesWithSplitLifeCycle.containsKey(assignment.getPlanNodeId())) { + schedulePartitionedSource(assignment); } else { - scheduleUnpartitionedSource(source, updatedUnpartitionedSources); + scheduleUnpartitionedSource(assignment, updatedUnpartitionedSplitAssignments); } } @@ -354,12 +354,12 @@ private synchronized Map updateSources(List } // update maxAcknowledgedSplit - maxAcknowledgedSplit = sources.stream() + maxAcknowledgedSplit = splitAssignments.stream() .flatMap(source -> source.getSplits().stream()) .mapToLong(ScheduledSplit::getSequenceId) .max() .orElse(maxAcknowledgedSplit); - return updatedUnpartitionedSources; + return updatedUnpartitionedSplitAssignments; } @GuardedBy("this") @@ -387,9 +387,9 @@ private void mergeIntoPendingSplits(PlanNodeId planNodeId, Set s } } - private synchronized void schedulePartitionedSource(TaskSource sourceUpdate) + private synchronized void schedulePartitionedSource(SplitAssignment splitAssignmentUpdate) { - 
mergeIntoPendingSplits(sourceUpdate.getPlanNodeId(), sourceUpdate.getSplits(), sourceUpdate.getNoMoreSplitsForLifespan(), sourceUpdate.isNoMoreSplits()); + mergeIntoPendingSplits(splitAssignmentUpdate.getPlanNodeId(), splitAssignmentUpdate.getSplits(), splitAssignmentUpdate.getNoMoreSplitsForLifespan(), splitAssignmentUpdate.isNoMoreSplits()); while (true) { // SchedulingLifespanManager tracks how far each Lifespan has been scheduled. Here is an example. @@ -409,10 +409,10 @@ private synchronized void schedulePartitionedSource(TaskSource sourceUpdate) SchedulingLifespan schedulingLifespan = activeLifespans.next(); Lifespan lifespan = schedulingLifespan.getLifespan(); - // Continue using the example from above. Let's say the sourceUpdate adds some new splits for source node B. + // Continue using the example from above. Let's say the splitAssignmentUpdate adds some new splits for source node B. // // For lifespan 30, it could start new drivers and assign a pending split to each. - // Pending splits could include both pre-existing pending splits, and the new ones from sourceUpdate. + // Pending splits could include both pre-existing pending splits, and the new ones from splitAssignmentUpdate. // If there is enough driver slots to deplete pending splits, one of the below would happen. // * If it is marked that all splits for node B in lifespan 30 has been received, SchedulingLifespanManager // will be updated so that lifespan 30 now processes source node C. It will immediately start processing them. 
@@ -477,27 +477,27 @@ private synchronized void schedulePartitionedSource(TaskSource sourceUpdate) } } - if (sourceUpdate.isNoMoreSplits()) { - schedulingLifespanManager.noMoreSplits(sourceUpdate.getPlanNodeId()); + if (splitAssignmentUpdate.isNoMoreSplits()) { + schedulingLifespanManager.noMoreSplits(splitAssignmentUpdate.getPlanNodeId()); } } - private synchronized void scheduleUnpartitionedSource(TaskSource sourceUpdate, Map updatedUnpartitionedSources) + private synchronized void scheduleUnpartitionedSource(SplitAssignment splitAssignmentUpdate, Map updatedUnpartitionedSources) { // create new source - TaskSource newSource; - TaskSource currentSource = unpartitionedSources.get(sourceUpdate.getPlanNodeId()); - if (currentSource == null) { - newSource = sourceUpdate; + SplitAssignment newSplitAssignment; + SplitAssignment currentSplitAssignment = unpartitionedSplitAssignments.get(splitAssignmentUpdate.getPlanNodeId()); + if (currentSplitAssignment == null) { + newSplitAssignment = splitAssignmentUpdate; } else { - newSource = currentSource.update(sourceUpdate); + newSplitAssignment = currentSplitAssignment.update(splitAssignmentUpdate); } // only record new source if something changed - if (newSource != currentSource) { - unpartitionedSources.put(sourceUpdate.getPlanNodeId(), newSource); - updatedUnpartitionedSources.put(sourceUpdate.getPlanNodeId(), newSource); + if (newSplitAssignment != currentSplitAssignment) { + unpartitionedSplitAssignments.put(splitAssignmentUpdate.getPlanNodeId(), newSplitAssignment); + updatedUnpartitionedSources.put(splitAssignmentUpdate.getPlanNodeId(), newSplitAssignment); } } @@ -611,9 +611,9 @@ public synchronized Set getNoMoreSplits() noMoreSplits.add(entry.getKey()); } } - for (TaskSource taskSource : unpartitionedSources.values()) { - if (taskSource.isNoMoreSplits()) { - noMoreSplits.add(taskSource.getPlanNodeId()); + for (SplitAssignment splitAssignment : unpartitionedSplitAssignments.values()) { + if 
(splitAssignment.isNoMoreSplits()) { + noMoreSplits.add(splitAssignment.getPlanNodeId()); } } return noMoreSplits.build(); @@ -655,7 +655,7 @@ public String toString() return toStringHelper(this) .add("taskId", taskId) .add("remainingDrivers", status.getRemainingDriver()) - .add("unpartitionedSources", unpartitionedSources) + .add("unpartitionedSplitAssignments", unpartitionedSplitAssignments) .toString(); } @@ -947,15 +947,15 @@ public Driver createDriver(DriverContext driverContext, @Nullable ScheduledSplit if (partitionedSplit != null) { // TableScanOperator requires partitioned split to be added before the first call to process - driver.updateSource(new TaskSource(partitionedSplit.getPlanNodeId(), ImmutableSet.of(partitionedSplit), true)); + driver.updateSplitAssignment(new SplitAssignment(partitionedSplit.getPlanNodeId(), ImmutableSet.of(partitionedSplit), true)); } // add unpartitioned sources Optional sourceId = driver.getSourceId(); if (sourceId.isPresent()) { - TaskSource taskSource = unpartitionedSources.get(sourceId.get()); - if (taskSource != null) { - driver.updateSource(taskSource); + SplitAssignment splitAssignment = unpartitionedSplitAssignments.get(sourceId.get()); + if (splitAssignment != null) { + driver.updateSplitAssignment(splitAssignment); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java index cfc69152e129..204bae3c2460 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java @@ -28,6 +28,7 @@ import io.airlift.units.Duration; import io.trino.Session; import io.trino.event.SplitMonitor; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.DynamicFiltersCollector.VersionedDynamicFilterDomains; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.buffer.BufferResult; @@ -133,7 +134,8 
@@ public SqlTaskManager( NodeMemoryConfig nodeMemoryConfig, LocalSpillManager localSpillManager, NodeSpillConfig nodeSpillConfig, - GcMonitor gcMonitor) + GcMonitor gcMonitor, + ExchangeManagerRegistry exchangeManagerRegistry) { requireNonNull(nodeInfo, "nodeInfo is null"); requireNonNull(config, "config is null"); @@ -174,6 +176,7 @@ public SqlTaskManager( sqlTask -> finishedTaskStats.merge(sqlTask.getIoStats()), maxBufferSize, maxBroadcastBufferSize, + requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"), failedTasks))); } @@ -373,12 +376,12 @@ public TaskInfo updateTask( Session session, TaskId taskId, Optional fragment, - List sources, + List splitAssignments, OutputBuffers outputBuffers, Map dynamicFilterDomains) { try { - return versionEmbedder.embedVersion(() -> doUpdateTask(session, taskId, fragment, sources, outputBuffers, dynamicFilterDomains)).call(); + return versionEmbedder.embedVersion(() -> doUpdateTask(session, taskId, fragment, splitAssignments, outputBuffers, dynamicFilterDomains)).call(); } catch (Exception e) { throwIfUnchecked(e); @@ -391,14 +394,14 @@ private TaskInfo doUpdateTask( Session session, TaskId taskId, Optional fragment, - List sources, + List splitAssignments, OutputBuffers outputBuffers, Map dynamicFilterDomains) { requireNonNull(session, "session is null"); requireNonNull(taskId, "taskId is null"); requireNonNull(fragment, "fragment is null"); - requireNonNull(sources, "sources is null"); + requireNonNull(splitAssignments, "splitAssignments is null"); requireNonNull(outputBuffers, "outputBuffers is null"); SqlTask sqlTask = tasks.getUnchecked(taskId); @@ -414,7 +417,7 @@ private TaskInfo doUpdateTask( } sqlTask.recordHeartbeat(); - return sqlTask.updateTask(session, fragment, sources, outputBuffers, dynamicFilterDomains); + return sqlTask.updateTask(session, fragment, splitAssignments, outputBuffers, dynamicFilterDomains); } @Override diff --git 
a/core/trino-main/src/main/java/io/trino/execution/TaskManager.java b/core/trino-main/src/main/java/io/trino/execution/TaskManager.java index 8ae16c746663..cc3abf39333f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskManager.java @@ -85,14 +85,14 @@ public interface TaskManager void updateMemoryPoolAssignments(MemoryPoolAssignmentsRequest assignments); /** - * Updates the task plan, sources and output buffers. If the task does not + * Updates the task plan, splitAssignments and output buffers. If the task does not * already exist, it is created and then updated. */ TaskInfo updateTask( Session session, TaskId taskId, Optional fragment, - List sources, + List splitAssignments, OutputBuffers outputBuffers, Map dynamicFilterDomains); diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/ExternalExchangeOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/ExternalExchangeOutputBuffer.java new file mode 100644 index 000000000000..efa1f4235c59 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/ExternalExchangeOutputBuffer.java @@ -0,0 +1,249 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.buffer; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.airlift.units.DataSize; +import io.trino.execution.StateMachine; +import io.trino.memory.context.LocalMemoryContext; +import io.trino.spi.exchange.ExchangeSink; + +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.concurrent.MoreFutures.asVoid; +import static io.airlift.concurrent.MoreFutures.toListenableFuture; +import static io.trino.execution.buffer.BufferState.FAILED; +import static io.trino.execution.buffer.BufferState.FINISHED; +import static io.trino.execution.buffer.BufferState.FLUSHING; +import static io.trino.execution.buffer.BufferState.NO_MORE_BUFFERS; +import static io.trino.execution.buffer.BufferState.OPEN; +import static io.trino.execution.buffer.OutputBuffers.BufferType.EXTERNAL; +import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_HEADER_SIZE; +import static io.trino.execution.buffer.PagesSerdeUtil.writeSerializedPage; +import static java.util.Objects.requireNonNull; + +public class ExternalExchangeOutputBuffer + implements OutputBuffer +{ + private final StateMachine state; + private final OutputBuffers outputBuffers; + private final ExchangeSink exchangeSink; + private final Supplier systemMemoryContextSupplier; + + private final AtomicLong peakMemoryUsage = new AtomicLong(); + private final AtomicLong totalPagesAdded = new AtomicLong(); + private final AtomicLong totalRowsAdded = new AtomicLong(); + + public ExternalExchangeOutputBuffer( + StateMachine state, + OutputBuffers outputBuffers, + ExchangeSink exchangeSink, + Supplier systemMemoryContextSupplier) + { + this.state = requireNonNull(state, "state is null"); + this.outputBuffers = 
requireNonNull(outputBuffers, "outputBuffers is null"); + checkArgument(outputBuffers.getType() == EXTERNAL, "Expected an EXTERNAL output buffer"); + this.exchangeSink = requireNonNull(exchangeSink, "exchangeSink is null"); + this.systemMemoryContextSupplier = requireNonNull(systemMemoryContextSupplier, "systemMemoryContextSupplier is null"); + + state.compareAndSet(OPEN, NO_MORE_BUFFERS); + } + + @Override + public OutputBufferInfo getInfo() + { + BufferState state = this.state.get(); + return new OutputBufferInfo( + "EXTERNAL", + state, + false, + state.canAddPages(), + exchangeSink.getSystemMemoryUsage(), + totalPagesAdded.get(), + totalRowsAdded.get(), + totalPagesAdded.get(), + ImmutableList.of()); + } + + @Override + public boolean isFinished() + { + return state.get() == FINISHED; + } + + @Override + public double getUtilization() + { + return 0; + } + + @Override + public boolean isOverutilized() + { + return false; + } + + @Override + public void addStateChangeListener(StateMachine.StateChangeListener stateChangeListener) + { + state.addStateChangeListener(stateChangeListener); + } + + @Override + public void setOutputBuffers(OutputBuffers newOutputBuffers) + { + requireNonNull(newOutputBuffers, "newOutputBuffers is null"); + + // ignore buffers added after query finishes, which can happen when a query is canceled + // also ignore old versions, which is normal + if (state.get().isTerminal() || outputBuffers.getVersion() >= newOutputBuffers.getVersion()) { + return; + } + + // no more buffers can be added but verify this is valid state change + outputBuffers.checkValidTransition(newOutputBuffers); + } + + @Override + public ListenableFuture get(OutputBuffers.OutputBufferId bufferId, long token, DataSize maxSize) + { + throw new UnsupportedOperationException(); + } + + @Override + public void acknowledge(OutputBuffers.OutputBufferId bufferId, long token) + { + throw new UnsupportedOperationException(); + } + + @Override + public void 
abort(OutputBuffers.OutputBufferId bufferId) + { + throw new UnsupportedOperationException(); + } + + @Override + public ListenableFuture isFull() + { + return asVoid(toListenableFuture(exchangeSink.isBlocked())); + } + + @Override + public void enqueue(List pages) + { + enqueue(0, pages); + } + + @Override + public void enqueue(int partition, List pages) + { + requireNonNull(pages, "pages is null"); + + // ignore pages after "no more pages" is set + // this can happen with a limit query + if (!state.get().canAddPages()) { + return; + } + + for (SerializedPage page : pages) { + // TODO: Avoid extra memory copy + Slice slice = Slices.allocate(page.getSizeInBytes() + SERIALIZED_PAGE_HEADER_SIZE); + writeSerializedPage(slice.getOutput(), page); + exchangeSink.add(partition, slice); + totalRowsAdded.addAndGet(page.getPositionCount()); + } + updateMemoryUsage(exchangeSink.getSystemMemoryUsage()); + totalPagesAdded.addAndGet(pages.size()); + } + + @Override + public void setNoMorePages() + { + if (state.compareAndSet(NO_MORE_BUFFERS, FLUSHING)) { + destroy(); + } + } + + @Override + public void destroy() + { + if (state.setIf(FINISHED, oldState -> !oldState.isTerminal())) { + try { + exchangeSink.finish(); + } + finally { + updateMemoryUsage(exchangeSink.getSystemMemoryUsage()); + } + } + } + + @Override + public void fail() + { + if (state.setIf(FAILED, oldState -> !oldState.isTerminal())) { + try { + exchangeSink.abort(); + } + finally { + updateMemoryUsage(0); + } + } + } + + @Override + public long getPeakMemoryUsage() + { + return peakMemoryUsage.get(); + } + + private void updateMemoryUsage(long bytes) + { + LocalMemoryContext context = getSystemMemoryContextOrNull(); + if (context != null) { + context.setBytes(bytes); + } + updatePeakMemoryUsage(bytes); + } + + private void updatePeakMemoryUsage(long bytes) + { + while (true) { + long currentValue = peakMemoryUsage.get(); + if (currentValue >= bytes) { + return; + } + if 
(peakMemoryUsage.compareAndSet(currentValue, bytes)) { + return; + } + } + } + + private LocalMemoryContext getSystemMemoryContextOrNull() + { + try { + return systemMemoryContextSupplier.get(); + } + catch (RuntimeException ignored) { + // This is possible with races, e.g., a task is created and then immediately aborted, + // so that the task context hasn't been created yet (as a result there's no memory context available). + return null; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/LazyOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/LazyOutputBuffer.java index c24fb48a7b86..d946940210b8 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/LazyOutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/LazyOutputBuffer.java @@ -18,11 +18,15 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.concurrent.ExtendedSettableFuture; import io.airlift.units.DataSize; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.StateMachine; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.TaskId; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.memory.context.LocalMemoryContext; +import io.trino.spi.exchange.ExchangeManager; +import io.trino.spi.exchange.ExchangeSink; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; @@ -54,6 +58,7 @@ public class LazyOutputBuffer private final Supplier systemMemoryContextSupplier; private final Executor executor; private final Runnable notifyStatusChanged; + private final ExchangeManagerRegistry exchangeManagerRegistry; // Note: this is a write once field, so an unsynchronized volatile read that returns a non-null value is safe, but if a null value is observed instead // a subsequent synchronized read is required to ensure the writing thread can 
complete any in-flight initialization @@ -73,9 +78,9 @@ public LazyOutputBuffer( DataSize maxBufferSize, DataSize maxBroadcastBufferSize, Supplier systemMemoryContextSupplier, - Runnable notifyStatusChanged) + Runnable notifyStatusChanged, + ExchangeManagerRegistry exchangeManagerRegistry) { - requireNonNull(taskId, "taskId is null"); this.taskInstanceId = requireNonNull(taskInstanceId, "taskInstanceId is null"); this.executor = requireNonNull(executor, "executor is null"); state = new StateMachine<>(taskId + "-buffer", executor, OPEN, TERMINAL_BUFFER_STATES); @@ -84,6 +89,7 @@ public LazyOutputBuffer( checkArgument(maxBufferSize.toBytes() > 0, "maxBufferSize must be at least 1"); this.systemMemoryContextSupplier = requireNonNull(systemMemoryContextSupplier, "systemMemoryContextSupplier is null"); this.notifyStatusChanged = requireNonNull(notifyStatusChanged, "notifyStatusChanged is null"); + this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); } @Override @@ -168,6 +174,15 @@ public void setOutputBuffers(OutputBuffers newOutputBuffers) case ARBITRARY: outputBuffer = new ArbitraryOutputBuffer(taskInstanceId, state, maxBufferSize, systemMemoryContextSupplier, executor); break; + case EXTERNAL: + ExchangeSinkInstanceHandle exchangeSinkInstanceHandle = newOutputBuffers.getExchangeSinkInstanceHandle() + .orElseThrow(() -> new IllegalArgumentException("exchange sink handle is expected to be present for buffer type EXTERNAL")); + ExchangeManager exchangeManager = exchangeManagerRegistry.getExchangeManager(); + ExchangeSink exchangeSink = exchangeManager.createSink(exchangeSinkInstanceHandle); + outputBuffer = new ExternalExchangeOutputBuffer(state, newOutputBuffers, exchangeSink, systemMemoryContextSupplier); + break; + default: + throw new IllegalArgumentException("Unexpected output buffer type: " + newOutputBuffers.getType()); } // process pending aborts and reads outside of synchronized lock diff --git 
a/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBuffers.java b/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBuffers.java index a4c42fd922bc..fb4582cd8d44 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBuffers.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBuffers.java @@ -17,18 +17,21 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.collect.ImmutableMap; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; import io.trino.sql.planner.PartitioningHandle; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; import java.util.Objects; +import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static io.trino.execution.buffer.OutputBuffers.BufferType.ARBITRARY; import static io.trino.execution.buffer.OutputBuffers.BufferType.BROADCAST; +import static io.trino.execution.buffer.OutputBuffers.BufferType.EXTERNAL; import static io.trino.execution.buffer.OutputBuffers.BufferType.PARTITIONED; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; @@ -41,7 +44,7 @@ public final class OutputBuffers public static OutputBuffers createInitialEmptyOutputBuffers(BufferType type) { - return new OutputBuffers(type, 0, false, ImmutableMap.of()); + return new OutputBuffers(type, 0, false, ImmutableMap.of(), Optional.empty()); } public static OutputBuffers createInitialEmptyOutputBuffers(PartitioningHandle partitioningHandle) @@ -56,7 +59,12 @@ else if (partitioningHandle.equals(FIXED_ARBITRARY_DISTRIBUTION)) { else { type = PARTITIONED; } - return new OutputBuffers(type, 0, false, ImmutableMap.of()); + 
return new OutputBuffers(type, 0, false, ImmutableMap.of(), Optional.empty()); + } + + public static OutputBuffers createExternalExchangeOutputBuffers(ExchangeSinkInstanceHandle exchangeSinkInstanceHandle) + { + return new OutputBuffers(EXTERNAL, 0, true, ImmutableMap.of(), Optional.of(exchangeSinkInstanceHandle)); } public enum BufferType @@ -64,12 +72,14 @@ public enum BufferType PARTITIONED, BROADCAST, ARBITRARY, + EXTERNAL, } private final BufferType type; private final long version; private final boolean noMoreBufferIds; private final Map buffers; + private final Optional exchangeSinkInstanceHandle; // Visible only for Jackson... Use the "with" methods instead @JsonCreator @@ -77,12 +87,14 @@ public OutputBuffers( @JsonProperty("type") BufferType type, @JsonProperty("version") long version, @JsonProperty("noMoreBufferIds") boolean noMoreBufferIds, - @JsonProperty("buffers") Map buffers) + @JsonProperty("buffers") Map buffers, + @JsonProperty("exchangeSinkInstanceHandle") Optional exchangeSinkInstanceHandle) { this.type = type; this.version = version; this.buffers = ImmutableMap.copyOf(requireNonNull(buffers, "buffers is null")); this.noMoreBufferIds = noMoreBufferIds; + this.exchangeSinkInstanceHandle = requireNonNull(exchangeSinkInstanceHandle, "exchangeSinkInstanceHandle is null"); } @JsonProperty @@ -109,6 +121,12 @@ public Map getBuffers() return buffers; } + @JsonProperty + public Optional getExchangeSinkInstanceHandle() + { + return exchangeSinkInstanceHandle; + } + public void checkValidTransition(OutputBuffers newOutputBuffers) { requireNonNull(newOutputBuffers, "newOutputBuffers is null"); @@ -186,7 +204,8 @@ public OutputBuffers withBuffer(OutputBufferId bufferId, int partition) ImmutableMap.builder() .putAll(buffers) .put(bufferId, partition) - .build()); + .build(), + exchangeSinkInstanceHandle); } public OutputBuffers withBuffers(Map buffers) @@ -218,7 +237,7 @@ public OutputBuffers withBuffers(Map buffers) // add the existing buffers 
newBuffers.putAll(this.buffers); - return new OutputBuffers(type, version + 1, false, newBuffers); + return new OutputBuffers(type, version + 1, false, newBuffers, exchangeSinkInstanceHandle); } public OutputBuffers withNoMoreBufferIds() @@ -227,7 +246,7 @@ public OutputBuffers withNoMoreBufferIds() return this; } - return new OutputBuffers(type, version + 1, true, buffers); + return new OutputBuffers(type, version + 1, true, buffers, exchangeSinkInstanceHandle); } private void checkHasBuffer(OutputBufferId bufferId, int partition) diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerdeUtil.java b/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerdeUtil.java index 1076211cc9e2..da142753b759 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerdeUtil.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerdeUtil.java @@ -41,6 +41,13 @@ private PagesSerdeUtil() {} * @implNote It's not just 0, so that hypothetical zero-ed out data is not treated as valid payload with no checksum. 
*/ public static final long NO_CHECKSUM = 0x0123456789abcdefL; + public static final int SERIALIZED_PAGE_HEADER_SIZE = /*positionCount*/ Integer.BYTES + + // pageCodecMarkers + Byte.BYTES + + // uncompressedSizeInBytes + Integer.BYTES + + // sizeInBytes + Integer.BYTES; static void writeRawPage(Page page, SliceOutput output, BlockEncodingSerde serde) { @@ -81,7 +88,7 @@ private static void updateChecksum(XxHash64 hash, SerializedPage page) hash.update(page.getSlice()); } - private static SerializedPage readSerializedPage(SliceInput sliceInput) + public static SerializedPage readSerializedPage(SliceInput sliceInput) { int positionCount = sliceInput.readInt(); PageCodecMarker.MarkerSet markers = PageCodecMarker.MarkerSet.fromByteValue(sliceInput.readByte()); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/BucketNodeMap.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/BucketNodeMap.java index 0bf6d6a880fe..c148e41b5cb0 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/BucketNodeMap.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/BucketNodeMap.java @@ -32,6 +32,8 @@ public BucketNodeMap(ToIntFunction splitToBucket) public abstract int getBucketCount(); + public abstract int getNodeCount(); + public abstract Optional getAssignedNode(int bucketedId); public abstract void assignBucketToNode(int bucketedId, InternalNode node); @@ -42,4 +44,9 @@ public final Optional getAssignedNode(Split split) { return getAssignedNode(splitToBucket.applyAsInt(split)); } + + public final int getBucket(Split split) + { + return splitToBucket.applyAsInt(split); + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java new file mode 100644 index 000000000000..8a3b8c6fe39d --- /dev/null +++ 
b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java @@ -0,0 +1,530 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.base.VerifyException; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Multimap; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.airlift.log.Logger; +import io.trino.Session; +import io.trino.execution.ExecutionFailureInfo; +import io.trino.execution.Lifespan; +import io.trino.execution.RemoteTask; +import io.trino.execution.SqlStage; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; +import io.trino.execution.TaskState; +import io.trino.execution.TaskStatus; +import io.trino.execution.buffer.OutputBuffers; +import io.trino.failuredetector.FailureDetector; +import io.trino.metadata.InternalNode; +import io.trino.metadata.Split; +import io.trino.spi.ErrorCode; +import io.trino.spi.TrinoException; +import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeSinkHandle; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; +import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.split.RemoteSplit; +import 
io.trino.split.RemoteSplit.ExternalExchangeInput; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNodeId; + +import javax.annotation.concurrent.GuardedBy; + +import java.util.ArrayDeque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Queue; +import java.util.Set; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Throwables.propagateIfPossible; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableListMultimap.flatteningToImmutableListMultimap; +import static com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.util.concurrent.Futures.allAsList; +import static com.google.common.util.concurrent.Futures.immediateVoidFuture; +import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; +import static io.airlift.concurrent.MoreFutures.asVoid; +import static io.airlift.concurrent.MoreFutures.getFutureValue; +import static io.airlift.concurrent.MoreFutures.toListenableFuture; +import static io.trino.execution.ExecutionFailureInfo.rewriteTransportFailure; +import static io.trino.execution.buffer.OutputBuffers.BufferType.PARTITIONED; +import static io.trino.execution.buffer.OutputBuffers.createExternalExchangeOutputBuffers; +import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; +import static io.trino.operator.ExchangeOperator.REMOTE_CONNECTOR_ID; +import static io.trino.spi.ErrorType.USER_ERROR; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.trino.util.Failures.toFailure; +import static 
java.util.Objects.requireNonNull; + +public class FaultTolerantStageScheduler +{ + private static final Logger log = Logger.get(FaultTolerantStageScheduler.class); + + private final Session session; + private final SqlStage stage; + private final FailureDetector failureDetector; + private final TaskSourceFactory taskSourceFactory; + private final NodeAllocator nodeAllocator; + + private final TaskLifecycleListener taskLifecycleListener; + // empty when the results are consumed via a direct exchange + private final Optional sinkExchange; + private final Optional sinkBucketToPartitionMap; + + private final Map sourceExchanges; + private final Optional sourceBucketToPartitionMap; + private final Optional sourceBucketNodeMap; + + @GuardedBy("this") + private ListenableFuture blocked = immediateVoidFuture(); + + @GuardedBy("this") + private ListenableFuture acquireNodeFuture; + @GuardedBy("this") + private SettableFuture taskFinishedFuture; + + @GuardedBy("this") + private TaskSource taskSource; + @GuardedBy("this") + private final Map partitionToTaskDescriptorMap = new HashMap<>(); + @GuardedBy("this") + private final Map partitionToExchangeSinkHandleMap = new HashMap<>(); + @GuardedBy("this") + private final Multimap partitionToRemoteTaskMap = ArrayListMultimap.create(); + @GuardedBy("this") + private final Map runningTasks = new HashMap<>(); + @GuardedBy("this") + private final Map runningNodes = new HashMap<>(); + @GuardedBy("this") + private final Queue queuedPartitions = new ArrayDeque<>(); + @GuardedBy("this") + private final Set finishedPartitions = new HashSet<>(); + @GuardedBy("this") + private int remainingRetryAttempts; + + @GuardedBy("this") + private Throwable failure; + @GuardedBy("this") + private boolean closed; + + public FaultTolerantStageScheduler( + Session session, + SqlStage stage, + FailureDetector failureDetector, + TaskSourceFactory taskSourceFactory, + NodeAllocator nodeAllocator, + TaskLifecycleListener taskLifecycleListener, + Optional 
sinkExchange, + Optional sinkBucketToPartitionMap, + Map sourceExchanges, + Optional sourceBucketToPartitionMap, + Optional sourceBucketNodeMap, + int retryAttempts) + { + checkArgument(!stage.getFragment().getStageExecutionDescriptor().isStageGroupedExecution(), "grouped execution is expected to be disabled"); + + this.session = requireNonNull(session, "session is null"); + this.stage = requireNonNull(stage, "stage is null"); + this.failureDetector = requireNonNull(failureDetector, "failureDetector is null"); + this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); + this.nodeAllocator = requireNonNull(nodeAllocator, "nodeAllocator is null"); + this.taskLifecycleListener = requireNonNull(taskLifecycleListener, "taskLifecycleListener is null"); + this.sinkExchange = requireNonNull(sinkExchange, "sinkExchange is null"); + this.sinkBucketToPartitionMap = requireNonNull(sinkBucketToPartitionMap, "sinkBucketToPartitionMap is null"); + this.sourceExchanges = ImmutableMap.copyOf(requireNonNull(sourceExchanges, "sourceExchanges is null")); + this.sourceBucketToPartitionMap = requireNonNull(sourceBucketToPartitionMap, "sourceBucketToPartitionMap is null"); + this.sourceBucketNodeMap = requireNonNull(sourceBucketNodeMap, "sourceBucketNodeMap is null"); + this.remainingRetryAttempts = retryAttempts; + } + + public StageId getStageId() + { + return stage.getStageId(); + } + + public synchronized ListenableFuture isBlocked() + { + return nonCancellationPropagating(blocked); + } + + public synchronized void schedule() + throws Exception + { + if (failure != null) { + propagateIfPossible(failure, Exception.class); + throw new RuntimeException(failure); + } + + if (closed) { + return; + } + + if (isFinished()) { + return; + } + + if (!blocked.isDone()) { + return; + } + + if (taskSource == null) { + Map>> exchangeSourceHandleMap = sourceExchanges.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> 
toListenableFuture(entry.getValue().getSourceHandles()))); + + List>> blockedFutures = exchangeSourceHandleMap.values().stream() + .filter(future -> !future.isDone()) + .collect(toImmutableList()); + + if (!blockedFutures.isEmpty()) { + blocked = asVoid(allAsList(blockedFutures)); + return; + } + + Multimap exchangeSources = exchangeSourceHandleMap.entrySet().stream() + .collect(flatteningToImmutableListMultimap(Map.Entry::getKey, entry -> getFutureValue(entry.getValue()).stream())); + + taskSource = taskSourceFactory.create( + session, + stage.getFragment(), + sourceExchanges, + exchangeSources, + stage::recordGetSplitTime, + sourceBucketToPartitionMap, + sourceBucketNodeMap); + } + + while (!queuedPartitions.isEmpty() || !taskSource.isFinished()) { + while (queuedPartitions.isEmpty() && !taskSource.isFinished()) { + List tasks = taskSource.getMoreTasks(); + for (TaskDescriptor task : tasks) { + queuedPartitions.add(task.getPartitionId()); + partitionToTaskDescriptorMap.put(task.getPartitionId(), task); + sinkExchange.ifPresent(exchange -> { + ExchangeSinkHandle exchangeSinkHandle = exchange.addSink(task.getPartitionId()); + partitionToExchangeSinkHandleMap.put(task.getPartitionId(), exchangeSinkHandle); + }); + } + if (taskSource.isFinished()) { + sinkExchange.ifPresent(Exchange::noMoreSinks); + } + } + + if (queuedPartitions.isEmpty()) { + break; + } + + int partition = queuedPartitions.peek(); + TaskDescriptor taskDescriptor = requireNonNull(partitionToTaskDescriptorMap.get(partition), () -> "task descriptor missing for partition: %s" + partition); + + if (acquireNodeFuture == null) { + acquireNodeFuture = nodeAllocator.acquire(taskDescriptor.getNodeRequirements()); + } + if (!acquireNodeFuture.isDone()) { + blocked = asVoid(acquireNodeFuture); + return; + } + InternalNode node = getFutureValue(acquireNodeFuture); + acquireNodeFuture = null; + + queuedPartitions.poll(); + + Multimap tableScanSplits = taskDescriptor.getSplits(); + Multimap remoteSplits = 
createRemoteSplits(taskDescriptor.getExchangeSourceHandles()); + + Multimap taskSplits = ImmutableListMultimap.builder() + .putAll(tableScanSplits) + .putAll(remoteSplits) + .build(); + + int attemptId = getNextAttemptIdForPartition(partition); + + OutputBuffers outputBuffers; + Optional exchangeSinkInstanceHandle; + if (sinkExchange.isPresent()) { + ExchangeSinkHandle sinkHandle = partitionToExchangeSinkHandleMap.get(partition); + exchangeSinkInstanceHandle = Optional.of(sinkExchange.get().instantiateSink(sinkHandle, attemptId)); + outputBuffers = createExternalExchangeOutputBuffers(exchangeSinkInstanceHandle.get()); + } + else { + exchangeSinkInstanceHandle = Optional.empty(); + // stage will be consumed by the coordinator using direct exchange + outputBuffers = createInitialEmptyOutputBuffers(PARTITIONED) + .withBuffer(new OutputBuffers.OutputBufferId(0), 0) + .withNoMoreBufferIds(); + } + + Set allSourcePlanNodeIds = new HashSet<>(stage.getFragment().getPartitionedSources()); + stage.getFragment().getRemoteSourceNodes().forEach(planNode -> allSourcePlanNodeIds.add(planNode.getId())); + + RemoteTask task = stage.createTask( + node, + partition, + attemptId, + sinkBucketToPartitionMap, + outputBuffers, + taskSplits, + allSourcePlanNodeIds.stream() + .collect(toImmutableListMultimap(Function.identity(), planNodeId -> Lifespan.taskWide())), + allSourcePlanNodeIds).orElseThrow(() -> new VerifyException("stage execution is expected to be active")); + + partitionToRemoteTaskMap.put(partition, task); + runningTasks.put(task.getTaskId(), task); + runningNodes.put(task.getTaskId(), node); + + if (taskFinishedFuture == null) { + taskFinishedFuture = SettableFuture.create(); + } + + taskLifecycleListener.taskCreated(stage.getFragment().getId(), task); + + task.addStateChangeListener(taskStatus -> updateTaskStatus(taskStatus, exchangeSinkInstanceHandle)); + task.start(); + } + + if (taskFinishedFuture != null && !taskFinishedFuture.isDone()) { + blocked = 
taskFinishedFuture; + } + } + + public synchronized boolean isFinished() + { + return failure == null && + taskSource != null && + taskSource.isFinished() && + queuedPartitions.isEmpty() && + finishedPartitions.containsAll(partitionToTaskDescriptorMap.keySet()); + } + + public void cancel() + { + close(false); + } + + public void abort() + { + close(true); + } + + private void fail(Throwable t) + { + synchronized (this) { + if (failure == null) { + failure = t; + } + } + close(true); + } + + private void close(boolean abort) + { + boolean closed; + synchronized (this) { + closed = this.closed; + this.closed = true; + } + if (!closed) { + cancelRunningTasks(abort); + cancelBlockedFuture(); + releaseAcquiredNode(); + closeTaskSource(); + closeSinkExchange(); + } + } + + private void cancelRunningTasks(boolean abort) + { + List tasks; + synchronized (this) { + tasks = ImmutableList.copyOf(runningTasks.values()); + } + if (abort) { + tasks.forEach(RemoteTask::abort); + } + else { + tasks.forEach(RemoteTask::cancel); + } + } + + private void cancelBlockedFuture() + { + verify(!Thread.holdsLock(this)); + ListenableFuture future; + synchronized (this) { + future = blocked; + } + if (future != null && !future.isDone()) { + future.cancel(true); + } + } + + private void releaseAcquiredNode() + { + verify(!Thread.holdsLock(this)); + ListenableFuture future; + synchronized (this) { + future = acquireNodeFuture; + acquireNodeFuture = null; + } + if (future != null) { + future.cancel(true); + if (future.isDone() && !future.isCancelled()) { + nodeAllocator.release(getFutureValue(future)); + } + } + } + + private void closeTaskSource() + { + TaskSource taskSource; + synchronized (this) { + taskSource = this.taskSource; + } + if (taskSource != null) { + try { + taskSource.close(); + } + catch (RuntimeException e) { + log.warn(e, "Error closing task source for stage: %s", stage.getStageId()); + } + } + } + + private void closeSinkExchange() + { + try { + 
sinkExchange.ifPresent(Exchange::close); + } + catch (RuntimeException e) { + log.warn(e, "Error closing sink exchange for stage: %s", stage.getStageId()); + } + } + + public synchronized void reportTaskFailure(TaskId taskId, Throwable failureCause) + { + RemoteTask task = runningTasks.get(taskId); + if (task != null) { + task.fail(failureCause); + } + } + + private int getNextAttemptIdForPartition(int partition) + { + int latestAttemptId = partitionToRemoteTaskMap.get(partition).stream() + .mapToInt(task -> task.getTaskId().getAttemptId()) + .max() + .orElse(-1); + return latestAttemptId + 1; + } + + private static Multimap createRemoteSplits(Multimap exchangeSourceHandles) + { + ImmutableListMultimap.Builder result = ImmutableListMultimap.builder(); + for (PlanNodeId planNodeId : exchangeSourceHandles.keySet()) { + result.put(planNodeId, new Split(REMOTE_CONNECTOR_ID, new RemoteSplit(new ExternalExchangeInput(ImmutableList.copyOf(exchangeSourceHandles.get(planNodeId)))), Lifespan.taskWide())); + } + return result.build(); + } + + private void updateTaskStatus(TaskStatus taskStatus, Optional exchangeSinkInstanceHandle) + { + TaskState state = taskStatus.getState(); + if (!state.isDone()) { + return; + } + + try { + RuntimeException failure = null; + SettableFuture future; + synchronized (this) { + TaskId taskId = taskStatus.getTaskId(); + + runningTasks.remove(taskId); + future = taskFinishedFuture; + if (!runningTasks.isEmpty()) { + taskFinishedFuture = SettableFuture.create(); + } + else { + taskFinishedFuture = null; + } + + InternalNode node = requireNonNull(runningNodes.remove(taskId), () -> "node not found for task id: " + taskId); + nodeAllocator.release(node); + + int partitionId = taskId.getPartitionId(); + + if (!finishedPartitions.contains(partitionId) && !closed) { + switch (state) { + case FINISHED: + finishedPartitions.add(partitionId); + if (sinkExchange.isPresent()) { + checkArgument(exchangeSinkInstanceHandle.isPresent(), 
"exchangeSinkInstanceHandle is expected to be present"); + sinkExchange.get().sinkFinished(exchangeSinkInstanceHandle.get()); + } + partitionToRemoteTaskMap.get(partitionId).forEach(RemoteTask::abort); + break; + case CANCELED: + log.info("Task cancelled: %s", taskId); + break; + case ABORTED: + log.info("Task aborted: %s", taskId); + break; + case FAILED: + ExecutionFailureInfo failureInfo = taskStatus.getFailures().stream() + .findFirst() + .map(f -> rewriteTransportFailure(failureDetector, f)) + .orElse(toFailure(new TrinoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason"))); + log.warn(failureInfo.toException(), "Task failed: %s", taskId); + ErrorCode errorCode = failureInfo.getErrorCode(); + if (remainingRetryAttempts > 0 && (errorCode == null || errorCode.getType() != USER_ERROR)) { + remainingRetryAttempts--; + // schedule failed tasks first + queuedPartitions.add(partitionId); + log.info("Retrying partition %s for stage %s", partitionId, stage.getStageId()); + } + else { + failure = failureInfo.toException(); + } + break; + default: + throw new IllegalArgumentException("Unexpected task state: " + state); + } + } + } + if (failure != null) { + // must be called outside the lock + fail(failure); + } + if (future != null && !future.isDone()) { + future.set(null); + } + } + catch (Throwable t) { + fail(t); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedBucketNodeMap.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedBucketNodeMap.java index 9322c3138390..54dc29b3e9c4 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedBucketNodeMap.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedBucketNodeMap.java @@ -14,6 +14,7 @@ package io.trino.execution.scheduler; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import io.trino.metadata.InternalNode; import io.trino.metadata.Split; @@ -47,6 
+48,12 @@ public int getBucketCount() return bucketToNode.size(); } + @Override + public int getNodeCount() + { + return ImmutableSet.copyOf(bucketToNode).size(); + } + @Override public void assignBucketToNode(int bucketedId, InternalNode node) { diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocator.java new file mode 100644 index 000000000000..01bc5f5892de --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocator.java @@ -0,0 +1,205 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.trino.Session; +import io.trino.connector.CatalogName; +import io.trino.metadata.InternalNode; +import io.trino.spi.TrinoException; + +import javax.annotation.concurrent.GuardedBy; + +import java.util.HashMap; +import java.util.IdentityHashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.Futures.immediateFailedFuture; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; +import static java.util.Comparator.comparing; +import static java.util.Objects.requireNonNull; + +public class FixedCountNodeAllocator + implements NodeAllocator +{ + private final NodeScheduler nodeScheduler; + + private final Session session; + private final int maximumAllocationsPerNode; + + @GuardedBy("this") + private final Map, NodeSelector> nodeSelectorCache = new HashMap<>(); + + @GuardedBy("this") + private final Map allocationCountMap = new HashMap<>(); + + @GuardedBy("this") + private final LinkedList pendingAcquires = new LinkedList<>(); + + public FixedCountNodeAllocator( + NodeScheduler nodeScheduler, + Session session, + int maximumAllocationsPerNode) + { + this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + this.session = session; + this.maximumAllocationsPerNode = maximumAllocationsPerNode; + } + + @Override + public synchronized ListenableFuture acquire(NodeRequirements requirements) + { + try { + Optional node = tryAcquireNode(requirements); + if (node.isPresent()) { + 
return immediateFuture(node.get()); + } + } + catch (RuntimeException e) { + return immediateFailedFuture(e); + } + + SettableFuture future = SettableFuture.create(); + PendingAcquire pendingAcquire = new PendingAcquire(requirements, future); + pendingAcquires.add(pendingAcquire); + + return future; + } + + @Override + public void release(InternalNode node) + { + releaseNodeInternal(node); + processPendingAcquires(); + } + + @Override + public void updateNodes() + { + processPendingAcquires(); + } + + private synchronized Optional tryAcquireNode(NodeRequirements requirements) + { + NodeSelector nodeSelector = nodeSelectorCache.computeIfAbsent(requirements.getCatalogName(), catalogName -> nodeScheduler.createNodeSelector(session, catalogName)); + + List nodes = nodeSelector.allNodes(); + if (nodes.isEmpty()) { + throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); + } + + List nodesMatchingRequirements = nodes.stream() + .filter(node -> requirements.getAddresses().isEmpty() || requirements.getAddresses().get().contains(node.getHostAndPort())) + .collect(toImmutableList()); + + if (nodesMatchingRequirements.isEmpty()) { + throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); + } + + Optional selectedNode = nodesMatchingRequirements.stream() + .filter(node -> allocationCountMap.getOrDefault(node, 0) < maximumAllocationsPerNode) + .min(comparing(node -> allocationCountMap.getOrDefault(node, 0))); + + if (selectedNode.isEmpty()) { + return Optional.empty(); + } + + allocationCountMap.compute(selectedNode.get(), (key, value) -> value == null ? 1 : value + 1); + return selectedNode; + } + + private synchronized void releaseNodeInternal(InternalNode node) + { + int allocationCount = allocationCountMap.compute(node, (key, value) -> value == null ? 
0 : value - 1); + checkState(allocationCount >= 0, "allocation count for node %s is expected to be greater than or equal to zero: %s", node, allocationCount); + } + + private void processPendingAcquires() + { + verify(!Thread.holdsLock(this)); + + Map assignedNodes = new IdentityHashMap<>(); + Map failures = new IdentityHashMap<>(); + synchronized (this) { + Iterator iterator = pendingAcquires.iterator(); + while (iterator.hasNext()) { + PendingAcquire pendingAcquire = iterator.next(); + if (pendingAcquire.getFuture().isCancelled()) { + iterator.remove(); + continue; + } + try { + Optional node = tryAcquireNode(pendingAcquire.getNodeRequirements()); + if (node.isPresent()) { + iterator.remove(); + assignedNodes.put(pendingAcquire, node.get()); + } + } + catch (RuntimeException e) { + iterator.remove(); + failures.put(pendingAcquire, e); + } + } + } + + assignedNodes.forEach((pendingAcquire, node) -> { + SettableFuture future = pendingAcquire.getFuture(); + future.set(node); + if (future.isCancelled()) { + releaseNodeInternal(node); + } + }); + + failures.forEach((pendingAcquire, failure) -> { + SettableFuture future = pendingAcquire.getFuture(); + future.setException(failure); + }); + } + + @Override + public synchronized void close() + { + } + + private static class PendingAcquire + { + private final NodeRequirements nodeRequirements; + private final SettableFuture future; + + private PendingAcquire(NodeRequirements nodeRequirements, SettableFuture future) + { + this.nodeRequirements = requireNonNull(nodeRequirements, "nodeRequirements is null"); + this.future = requireNonNull(future, "future is null"); + } + + public NodeRequirements getNodeRequirements() + { + return nodeRequirements; + } + + public SettableFuture getFuture() + { + return future; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java new file mode 100644 index 
000000000000..778c059982e6 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.util.concurrent.ListenableFuture; +import io.trino.metadata.InternalNode; + +import java.io.Closeable; + +public interface NodeAllocator + extends Closeable +{ + ListenableFuture acquire(NodeRequirements requirements); + + void release(InternalNode node); + + void updateNodes(); + + @Override + void close(); +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeRequirements.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeRequirements.java new file mode 100644 index 000000000000..035dd137a948 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeRequirements.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ImmutableSet; +import io.trino.connector.CatalogName; +import io.trino.spi.HostAddress; + +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class NodeRequirements +{ + private final Optional catalogName; + private final Optional> addresses; + + public NodeRequirements(Optional catalogName, Optional> addresses) + { + this.catalogName = requireNonNull(catalogName, "catalogName is null"); + this.addresses = requireNonNull(addresses, "addresses is null").map(ImmutableSet::copyOf); + } + + public Optional getCatalogName() + { + return catalogName; + } + + public Optional> getAddresses() + { + return addresses; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + NodeRequirements that = (NodeRequirements) o; + return Objects.equals(catalogName, that.catalogName) && Objects.equals(addresses, that.addresses); + } + + @Override + public int hashCode() + { + return Objects.hash(catalogName, addresses); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("catalogName", catalogName) + .add("addresses", addresses) + .toString(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java index d9812b0a2233..3344393e957f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java @@ -38,6 +38,7 @@ import io.trino.metadata.Split; 
import io.trino.spi.TrinoException; import io.trino.split.RemoteSplit; +import io.trino.split.RemoteSplit.DirectExchangeInput; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNodeId; @@ -65,6 +66,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; +import static io.trino.execution.ExecutionFailureInfo.rewriteTransportFailure; import static io.trino.execution.scheduler.PipelinedStageExecution.State.ABORTED; import static io.trino.execution.scheduler.PipelinedStageExecution.State.CANCELED; import static io.trino.execution.scheduler.PipelinedStageExecution.State.FAILED; @@ -75,10 +77,8 @@ import static io.trino.execution.scheduler.PipelinedStageExecution.State.SCHEDULED; import static io.trino.execution.scheduler.PipelinedStageExecution.State.SCHEDULING; import static io.trino.execution.scheduler.PipelinedStageExecution.State.SCHEDULING_SPLITS; -import static io.trino.failuredetector.FailureDetector.State.GONE; import static io.trino.operator.ExchangeOperator.REMOTE_CONNECTOR_ID; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; -import static io.trino.spi.StandardErrorCode.REMOTE_HOST_GONE; import static java.util.Objects.requireNonNull; public class PipelinedStageExecution @@ -354,7 +354,7 @@ private synchronized void updateTaskStatus(TaskStatus taskStatus) case FAILED: RuntimeException failure = taskStatus.getFailures().stream() .findFirst() - .map(this::rewriteTransportFailure) + .map(f -> rewriteTransportFailure(failureDetector, f)) .map(ExecutionFailureInfo::toException) .orElse(new TrinoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason")); fail(failure); @@ -407,23 +407,6 @@ private synchronized void updateCompletedDriverGroups(TaskStatus taskStatus) 
completedDriverGroups.addAll(newlyCompletedDriverGroups); } - private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo) - { - if (executionFailureInfo.getRemoteHost() == null || failureDetector.getState(executionFailureInfo.getRemoteHost()) != GONE) { - return executionFailureInfo; - } - - return new ExecutionFailureInfo( - executionFailureInfo.getType(), - executionFailureInfo.getMessage(), - executionFailureInfo.getCause(), - executionFailureInfo.getSuppressed(), - executionFailureInfo.getStack(), - executionFailureInfo.getErrorLocation(), - REMOTE_HOST_GONE.toErrorCode(), - executionFailureInfo.getRemoteHost()); - } - public TaskLifecycleListener getTaskLifecycleListener() { return new TaskLifecycleListener() @@ -517,7 +500,7 @@ private static Split createExchangeSplit(RemoteTask sourceTask, RemoteTask desti // Fetch the results from the buffer assigned to the task based on id URI exchangeLocation = sourceTask.getTaskStatus().getSelf(); URI splitLocation = uriBuilderFrom(exchangeLocation).appendPath("results").appendPath(String.valueOf(destinationTask.getTaskId().getPartitionId())).build(); - return new Split(REMOTE_CONNECTOR_ID, new RemoteSplit(sourceTask.getTaskId(), splitLocation), Lifespan.taskWide()); + return new Split(REMOTE_CONNECTOR_ID, new RemoteSplit(new DirectExchangeInput(sourceTask.getTaskId(), splitLocation)), Lifespan.taskWide()); } public enum State diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java index a0742db9f03d..61e6351e8081 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; +import 
com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; import com.google.common.graph.Traverser; import com.google.common.primitives.Ints; @@ -28,6 +29,7 @@ import io.airlift.units.Duration; import io.trino.Session; import io.trino.connector.CatalogName; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.BasicStageStats; import io.trino.execution.ExecutionFailureInfo; import io.trino.execution.NodeTaskMap; @@ -57,6 +59,9 @@ import io.trino.spi.QueryId; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorPartitionHandle; +import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeContext; +import io.trino.spi.exchange.ExchangeManager; import io.trino.split.SplitSource; import io.trino.sql.planner.NodePartitionMap; import io.trino.sql.planner.NodePartitioningManager; @@ -83,16 +88,19 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CancellationException; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.IntStream; import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; @@ -104,12 +112,14 @@ import static com.google.common.collect.Iterables.getFirst; import static com.google.common.collect.Iterables.getLast; import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.Lists.reverse; import static com.google.common.collect.Sets.newConcurrentHashSet; import static 
com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; import static io.airlift.concurrent.MoreFutures.whenAnyComplete; import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static io.trino.SystemSessionProperties.getConcurrentLifespansPerNode; +import static io.trino.SystemSessionProperties.getHashPartitionCount; import static io.trino.SystemSessionProperties.getRetryAttempts; import static io.trino.SystemSessionProperties.getRetryInitialDelay; import static io.trino.SystemSessionProperties.getRetryMaxDelay; @@ -117,6 +127,7 @@ import static io.trino.SystemSessionProperties.getWriterMinSize; import static io.trino.connector.CatalogName.isInternalSystemConnector; import static io.trino.execution.BasicStageStats.aggregateBasicStageStats; +import static io.trino.execution.QueryState.FINISHING; import static io.trino.execution.SqlStage.createSqlStage; import static io.trino.execution.scheduler.PipelinedStageExecution.State.ABORTED; import static io.trino.execution.scheduler.PipelinedStageExecution.State.CANCELED; @@ -135,6 +146,7 @@ import static io.trino.spi.StandardErrorCode.REMOTE_TASK_FAILED; import static io.trino.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; @@ -168,6 +180,8 @@ public class SqlQueryScheduler private final DynamicFilterService dynamicFilterService; private final TableExecuteContextManager tableExecuteContextManager; private final SplitSourceFactory splitSourceFactory; + private final ExchangeManagerRegistry 
exchangeManagerRegistry; + private final TaskSourceFactory taskSourceFactory; private final StageManager stageManager; private final CoordinatorStagesScheduler coordinatorStagesScheduler; @@ -204,7 +218,9 @@ public SqlQueryScheduler( TableExecuteContextManager tableExecuteContextManager, Metadata metadata, SplitSourceFactory splitSourceFactory, - TaskManager coordinatorTaskManager) + TaskManager coordinatorTaskManager, + ExchangeManagerRegistry exchangeManagerRegistry, + TaskSourceFactory taskSourceFactory) { this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); @@ -218,6 +234,8 @@ public SqlQueryScheduler( this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); + this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); + this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); stageManager = StageManager.create( queryStateMachine, @@ -294,25 +312,50 @@ else if (state == QueryState.FAILED) { private synchronized Optional createDistributedStagesScheduler(int attempt) { + verify(attempt == 0 || retryPolicy == RetryPolicy.QUERY, "unexpected attempt %s for retry policy %s", attempt, retryPolicy); if (queryStateMachine.isDone()) { return Optional.empty(); } - DistributedStagesScheduler distributedStagesScheduler = PipelinedDistributedStagesScheduler.create( - queryStateMachine, - schedulerStats, - nodeScheduler, - nodePartitioningManager, - stageManager, - coordinatorStagesScheduler, - executionPolicy, - failureDetector, - schedulerExecutor, - splitSourceFactory, - splitBatchSize, - dynamicFilterService, 
- tableExecuteContextManager, - retryPolicy, - attempt); + DistributedStagesScheduler distributedStagesScheduler; + switch (retryPolicy) { + case TASK: + ExchangeManager exchangeManager = exchangeManagerRegistry.getExchangeManager(); + distributedStagesScheduler = FaultTolerantDistributedStagesScheduler.create( + queryStateMachine, + stageManager, + failureDetector, + taskSourceFactory, + exchangeManager, + nodePartitioningManager, + coordinatorStagesScheduler.getTaskLifecycleListener(), + maxRetryAttempts, + schedulerExecutor, + schedulerStats, + nodeScheduler); + break; + case QUERY: + case NONE: + distributedStagesScheduler = PipelinedDistributedStagesScheduler.create( + queryStateMachine, + schedulerStats, + nodeScheduler, + nodePartitioningManager, + stageManager, + coordinatorStagesScheduler, + executionPolicy, + failureDetector, + schedulerExecutor, + splitSourceFactory, + splitBatchSize, + dynamicFilterService, + tableExecuteContextManager, + retryPolicy, + attempt); + break; + default: + throw new IllegalArgumentException("Unexpected retry policy: " + retryPolicy); + } + this.distributedStagesScheduler.set(distributedStagesScheduler); distributedStagesScheduler.addStateChangeListener(state -> { if (queryStateMachine.getQueryState() == QueryState.STARTING && state.isRunningOrDone()) { @@ -1597,6 +1640,327 @@ public Optional getFailureCause() } } + private static class FaultTolerantDistributedStagesScheduler + implements DistributedStagesScheduler + { + private final DistributedStagesSchedulerStateMachine stateMachine; + private final QueryStateMachine queryStateMachine; + private final List schedulers; + private final SplitSchedulerStats schedulerStats; + private final NodeAllocator nodeAllocator; + private final ScheduledFuture nodeUpdateTask; + + private final AtomicBoolean started = new AtomicBoolean(); + + public static FaultTolerantDistributedStagesScheduler create( + QueryStateMachine queryStateMachine, + StageManager stageManager, + FailureDetector 
failureDetector, + TaskSourceFactory taskSourceFactory, + ExchangeManager exchangeManager, + NodePartitioningManager nodePartitioningManager, + TaskLifecycleListener coordinatorTaskLifecycleListener, + int retryAttempts, + ScheduledExecutorService scheduledExecutorService, + SplitSchedulerStats schedulerStats, + NodeScheduler nodeScheduler) + { + DistributedStagesSchedulerStateMachine stateMachine = new DistributedStagesSchedulerStateMachine(queryStateMachine.getQueryId(), scheduledExecutorService); + + Session session = queryStateMachine.getSession(); + int hashPartitionCount = getHashPartitionCount(session); + Map bucketToPartitionCacheMap = new HashMap<>(); + Function bucketToPartitionMapCache = partitioningHandle -> + bucketToPartitionCacheMap.computeIfAbsent(partitioningHandle, handle -> createBucketToPartitionMap(session, hashPartitionCount, handle, nodePartitioningManager)); + + ImmutableList.Builder schedulers = ImmutableList.builder(); + Map exchanges = new HashMap<>(); + + FixedCountNodeAllocator nodeAllocator = new FixedCountNodeAllocator(nodeScheduler, session, 1); + ScheduledFuture nodeUpdateTask = scheduledExecutorService.scheduleAtFixedRate(nodeAllocator::updateNodes, 5, 5, SECONDS); + + try { + // top to bottom order + List distributedStagesInTopologicalOrder = stageManager.getDistributedStagesInTopologicalOrder(); + // bottom to top order + List distributedStagesInReverseTopologicalOrder = reverse(distributedStagesInTopologicalOrder); + + ImmutableSet.Builder coordinatorConsumedFragmentsBuilder = ImmutableSet.builder(); + + for (SqlStage stage : distributedStagesInReverseTopologicalOrder) { + PlanFragment fragment = stage.getFragment(); + Optional parentStage = stageManager.getParent(stage.getStageId()); + TaskLifecycleListener taskLifecycleListener; + Optional exchange; + if (parentStage.isEmpty() || parentStage.get().getFragment().getPartitioning().isCoordinatorOnly()) { + // output will be consumed by coordinator + exchange = Optional.empty(); + 
taskLifecycleListener = coordinatorTaskLifecycleListener; + coordinatorConsumedFragmentsBuilder.add(fragment.getId()); + } + else { + // create external exchange + exchange = Optional.of(exchangeManager.create(new ExchangeContext(session.getQueryId(), stage.getStageId().getId()), hashPartitionCount)); + exchanges.put(fragment.getId(), exchange.get()); + taskLifecycleListener = TaskLifecycleListener.NO_OP; + } + + ImmutableMap.Builder sourceExchanges = ImmutableMap.builder(); + for (SqlStage childStage : stageManager.getChildren(fragment.getId())) { + PlanFragmentId childFragmentId = childStage.getFragment().getId(); + Exchange sourceExchange = exchanges.get(childFragmentId); + verify(sourceExchange != null, "exchange not found for fragment: %s", childFragmentId); + sourceExchanges.put(childFragmentId, sourceExchange); + } + + BucketToPartition inputBucketToPartition = bucketToPartitionMapCache.apply(fragment.getPartitioning()); + FaultTolerantStageScheduler scheduler = new FaultTolerantStageScheduler( + session, + stage, + failureDetector, + taskSourceFactory, + nodeAllocator, + taskLifecycleListener, + exchange, + bucketToPartitionMapCache.apply(fragment.getPartitioningScheme().getPartitioning().getHandle()).getBucketToPartitionMap(), + sourceExchanges.build(), + inputBucketToPartition.getBucketToPartitionMap(), + inputBucketToPartition.getBucketNodeMap(), + retryAttempts); + + schedulers.add(scheduler); + } + + Set coordinatorConsumedFragments = coordinatorConsumedFragmentsBuilder.build(); + stateMachine.addStateChangeListener(state -> { + if (state == DistributedStagesSchedulerState.FINISHED) { + coordinatorConsumedFragments.forEach(coordinatorTaskLifecycleListener::noMoreTasks); + } + }); + + return new FaultTolerantDistributedStagesScheduler( + stateMachine, + queryStateMachine, + schedulers.build(), + schedulerStats, + nodeAllocator, + nodeUpdateTask); + } + catch (Throwable t) { + schedulers.build().forEach(FaultTolerantStageScheduler::abort); + + 
nodeUpdateTask.cancel(true); + try { + nodeAllocator.close(); + } + catch (Throwable closeFailure) { + if (t != closeFailure) { + t.addSuppressed(closeFailure); + } + } + + for (Exchange exchange : exchanges.values()) { + try { + exchange.close(); + } + catch (Throwable closeFailure) { + if (t != closeFailure) { + t.addSuppressed(closeFailure); + } + } + } + throw t; + } + } + + private static BucketToPartition createBucketToPartitionMap( + Session session, + int hashPartitionCount, + PartitioningHandle partitioningHandle, + NodePartitioningManager nodePartitioningManager) + { + if (partitioningHandle.equals(FIXED_HASH_DISTRIBUTION)) { + return new BucketToPartition(Optional.of(IntStream.range(0, hashPartitionCount).toArray()), Optional.empty()); + } + else if (partitioningHandle.getConnectorId().isPresent()) { + int partitionCount = hashPartitionCount; + BucketNodeMap bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle, true); + if (!bucketNodeMap.isDynamic()) { + partitionCount = bucketNodeMap.getNodeCount(); + } + int bucketCount = bucketNodeMap.getBucketCount(); + int[] bucketToPartition = new int[bucketCount]; + int nextPartitionId = 0; + for (int bucket = 0; bucket < bucketCount; bucket++) { + bucketToPartition[bucket] = nextPartitionId++ % partitionCount; + } + return new BucketToPartition(Optional.of(bucketToPartition), Optional.of(bucketNodeMap)); + } + else { + return new BucketToPartition(Optional.empty(), Optional.empty()); + } + } + + private static class BucketToPartition + { + private final Optional bucketToPartitionMap; + private final Optional bucketNodeMap; + + private BucketToPartition(Optional bucketToPartitionMap, Optional bucketNodeMap) + { + this.bucketToPartitionMap = requireNonNull(bucketToPartitionMap, "bucketToPartitionMap is null"); + this.bucketNodeMap = requireNonNull(bucketNodeMap, "bucketNodeMap is null"); + } + + public Optional getBucketToPartitionMap() + { + return bucketToPartitionMap; + } + + 
public Optional getBucketNodeMap() + { + return bucketNodeMap; + } + } + + private FaultTolerantDistributedStagesScheduler( + DistributedStagesSchedulerStateMachine stateMachine, + QueryStateMachine queryStateMachine, + List schedulers, + SplitSchedulerStats schedulerStats, + NodeAllocator nodeAllocator, + ScheduledFuture nodeUpdateTask) + { + this.stateMachine = requireNonNull(stateMachine, "stateMachine is null"); + this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); + this.schedulers = requireNonNull(schedulers, "schedulers is null"); + this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); + this.nodeAllocator = requireNonNull(nodeAllocator, "nodeAllocator is null"); + this.nodeUpdateTask = requireNonNull(nodeUpdateTask, "nodeUpdateTask is null"); + } + + @Override + public void schedule() + { + checkState(started.compareAndSet(false, true), "already started"); + + try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { + List> blockedStages = new ArrayList<>(); + while (!isFinishingOrDone(queryStateMachine) && !stateMachine.getState().isDone()) { + blockedStages.clear(); + boolean atLeastOneStageIsNotBlocked = false; + boolean allFinished = true; + for (FaultTolerantStageScheduler scheduler : schedulers) { + if (scheduler.isFinished()) { + continue; + } + allFinished = false; + ListenableFuture blocked = scheduler.isBlocked(); + if (!blocked.isDone()) { + blockedStages.add(blocked); + continue; + } + try { + scheduler.schedule(); + } + catch (Throwable t) { + stateMachine.transitionToFailed(t, Optional.of(scheduler.getStageId())); + return; + } + blocked = scheduler.isBlocked(); + if (!blocked.isDone()) { + blockedStages.add(blocked); + } + else { + atLeastOneStageIsNotBlocked = true; + } + } + if (allFinished) { + stateMachine.transitionToFinished(); + return; + } + // wait for a state change and then schedule again + if (!atLeastOneStageIsNotBlocked && 
!blockedStages.isEmpty()) { + try (TimeStat.BlockTimer timer = schedulerStats.getSleepTime().time()) { + try { + tryGetFutureValue(whenAnyComplete(blockedStages), 1, SECONDS); + } + catch (CancellationException e) { + log.debug("Future cancelled"); + } + } + } + } + } + catch (Throwable t) { + stateMachine.transitionToFailed(t, Optional.empty()); + schedulers.forEach(FaultTolerantStageScheduler::abort); + closeNodeAllocator(); + } + } + + private static boolean isFinishingOrDone(QueryStateMachine queryStateMachine) + { + QueryState queryState = queryStateMachine.getQueryState(); + return queryState == FINISHING || queryState.isDone(); + } + + @Override + public void cancelStage(StageId stageId) + { + // single stage cancellation is not supported in fault tolerant mode + } + + @Override + public void cancel() + { + stateMachine.transitionToCanceled(); + schedulers.forEach(FaultTolerantStageScheduler::cancel); + closeNodeAllocator(); + } + + @Override + public void abort() + { + stateMachine.transitionToAborted(); + schedulers.forEach(FaultTolerantStageScheduler::abort); + closeNodeAllocator(); + } + + private void closeNodeAllocator() + { + nodeUpdateTask.cancel(true); + try { + nodeAllocator.close(); + } + catch (Throwable t) { + log.warn(t, "Error closing node allocator for query: %s", queryStateMachine.getQueryId()); + } + } + + @Override + public void reportTaskFailure(TaskId taskId, Throwable failureCause) + { + for (FaultTolerantStageScheduler scheduler : schedulers) { + if (scheduler.getStageId().equals(taskId.getStageId())) { + scheduler.reportTaskFailure(taskId, failureCause); + } + } + } + + @Override + public void addStateChangeListener(StateChangeListener stateChangeListener) + { + stateMachine.addStateChangeListener(stateChangeListener); + } + + @Override + public Optional getFailureCause() + { + return stateMachine.getFailureCause(); + } + } + private enum DistributedStagesSchedulerState { PLANNED(false, false), diff --git 
a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java new file mode 100644 index 000000000000..4264ae3ee486 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java @@ -0,0 +1,705 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.base.VerifyException; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multimap; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.log.Logger; +import io.airlift.units.DataSize; +import io.trino.Session; +import io.trino.connector.CatalogName; +import io.trino.execution.Lifespan; +import io.trino.execution.QueryManagerConfig; +import io.trino.execution.TableExecuteContext; +import io.trino.execution.TableExecuteContextManager; +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.spi.QueryId; +import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.spi.exchange.ExchangeSourceSplitter; +import 
io.trino.spi.exchange.ExchangeSourceStatistics; +import io.trino.split.SplitSource; +import io.trino.split.SplitSource.SplitBatch; +import io.trino.sql.planner.PartitioningHandle; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.SplitSourceFactory; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.RemoteSourceNode; + +import javax.inject.Inject; + +import java.util.ArrayDeque; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Queue; +import java.util.Set; +import java.util.function.LongConsumer; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.Sets.newIdentityHashSet; +import static com.google.common.collect.Sets.union; +import static io.airlift.concurrent.MoreFutures.addSuccessCallback; +import static io.airlift.concurrent.MoreFutures.getFutureValue; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionTargetTaskInputSize; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionTargetTaskSplitCount; +import static io.trino.connector.CatalogName.isInternalSystemConnector; +import static io.trino.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static 
io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; +import static io.trino.sql.planner.plan.ExchangeNode.Type.GATHER; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; +import static java.util.Objects.requireNonNull; + +public class StageTaskSourceFactory + implements TaskSourceFactory +{ + private static final Logger log = Logger.get(StageTaskSourceFactory.class); + + private final SplitSourceFactory splitSourceFactory; + private final TableExecuteContextManager tableExecuteContextManager; + private final int splitBatchSize; + + @Inject + public StageTaskSourceFactory( + SplitSourceFactory splitSourceFactory, + TableExecuteContextManager tableExecuteContextManager, + QueryManagerConfig queryManagerConfig) + { + this(splitSourceFactory, tableExecuteContextManager, requireNonNull(queryManagerConfig, "queryManagerConfig is null").getScheduleSplitBatchSize()); + } + + public StageTaskSourceFactory( + SplitSourceFactory splitSourceFactory, + TableExecuteContextManager tableExecuteContextManager, + int splitBatchSize) + { + this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); + this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); + this.splitBatchSize = splitBatchSize; + } + + @Override + public TaskSource create( + Session session, + PlanFragment fragment, + Map sourceExchanges, + Multimap exchangeSourceHandles, + LongConsumer getSplitTimeRecorder, + Optional bucketToPartitionMap, + Optional bucketNodeMap) + { + PartitioningHandle partitioning = fragment.getPartitioning(); + + if (partitioning.equals(SINGLE_DISTRIBUTION)) { + return SingleDistributionTaskSource.create(fragment, exchangeSourceHandles); + } + else if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioning.equals(SCALED_WRITER_DISTRIBUTION)) { + return ArbitraryDistributionTaskSource.create( + fragment, + sourceExchanges, + exchangeSourceHandles, + 
getFaultTolerantExecutionTargetTaskInputSize(session)); + } + else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getConnectorId().isPresent()) { + return HashDistributionTaskSource.create( + session, + fragment, + splitSourceFactory, + exchangeSourceHandles, + splitBatchSize, + getSplitTimeRecorder, + bucketToPartitionMap, + bucketNodeMap); + } + else if (partitioning.equals(SOURCE_DISTRIBUTION)) { + return SourceDistributionTaskSource.create( + session, + fragment, + exchangeSourceHandles, + splitSourceFactory, + tableExecuteContextManager, + splitBatchSize, + getSplitTimeRecorder, + getFaultTolerantExecutionTargetTaskSplitCount(session)); + } + + // other partitioning handles are not expected to be set as a fragment partitioning + throw new IllegalArgumentException("Unexpected partitioning: " + partitioning); + } + + public static class SingleDistributionTaskSource + implements TaskSource + { + private final Multimap exchangeSourceHandles; + + private boolean finished; + + public static SingleDistributionTaskSource create(PlanFragment fragment, Multimap exchangeSourceHandles) + { + checkArgument(fragment.getPartitionedSources().isEmpty(), "no partitioned sources (table scans) expected, got: %s", fragment.getPartitionedSources()); + return new SingleDistributionTaskSource(getInputsForRemoteSources(fragment.getRemoteSourceNodes(), exchangeSourceHandles)); + } + + public SingleDistributionTaskSource(Multimap exchangeSourceHandles) + { + this.exchangeSourceHandles = ImmutableMultimap.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null")); + } + + @Override + public List getMoreTasks() + { + List result = ImmutableList.of(new TaskDescriptor( + 0, + ImmutableMultimap.of(), + exchangeSourceHandles, + new NodeRequirements(Optional.empty(), Optional.empty()))); + finished = true; + return result; + } + + @Override + public boolean isFinished() + { + return finished; + } + + @Override + public void close() + { + } + } + + public 
static class ArbitraryDistributionTaskSource + implements TaskSource + { + private final Map sourceFragmentToRemoteSourceNodeIdMap; + private final Map sourceExchanges; + private final Multimap exchangeSourceHandles; + private final long targetPartitionSizeInBytes; + + private boolean finished; + + public static ArbitraryDistributionTaskSource create( + PlanFragment fragment, + Map sourceExchanges, + Multimap exchangeSourceHandles, + DataSize targetPartitionSize) + { + checkArgument(fragment.getPartitionedSources().isEmpty(), "no partitioned sources (table scans) expected, got: %s", fragment.getPartitionedSources()); + checkArgument(fragment.getRemoteSourceNodes().stream().noneMatch(node -> node.getExchangeType() == REPLICATE), "replicated exchanges are not expected in source distributed stage, got: %s", fragment.getRemoteSourceNodes()); + + return new ArbitraryDistributionTaskSource( + getSourceFragmentToRemoteSourceNodeIdMap(fragment.getRemoteSourceNodes()), + sourceExchanges, + exchangeSourceHandles, + targetPartitionSize); + } + + public ArbitraryDistributionTaskSource( + Map sourceFragmentToRemoteSourceNodeIdMap, + Map sourceExchanges, + Multimap exchangeSourceHandles, + DataSize targetPartitionSize) + { + this.sourceFragmentToRemoteSourceNodeIdMap = ImmutableMap.copyOf(requireNonNull(sourceFragmentToRemoteSourceNodeIdMap, "sourceFragmentToRemoteSourceNodeIdMap is null")); + this.sourceExchanges = ImmutableMap.copyOf(requireNonNull(sourceExchanges, "sourceExchanges is null")); + this.exchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null")); + this.targetPartitionSizeInBytes = requireNonNull(targetPartitionSize, "targetPartitionSize is null").toBytes(); + } + + @Override + public List getMoreTasks() + { + NodeRequirements nodeRequirements = new NodeRequirements(Optional.empty(), Optional.empty()); + + ImmutableList.Builder result = ImmutableList.builder(); + int currentPartitionId = 0; + + 
ImmutableListMultimap.Builder assignedExchangeSourceHandles = ImmutableListMultimap.builder(); + long assignedExchangeDataSize = 0; + + for (Map.Entry entry : exchangeSourceHandles.entries()) { + PlanFragmentId sourceFragmentId = entry.getKey(); + PlanNodeId remoteSourcePlanNodeId = sourceFragmentToRemoteSourceNodeIdMap.get(sourceFragmentId); + ExchangeSourceHandle originalExchangeSourceHandle = entry.getValue(); + Exchange sourceExchange = sourceExchanges.get(sourceFragmentId); + + ExchangeSourceSplitter splitter = sourceExchange.split(originalExchangeSourceHandle, targetPartitionSizeInBytes); + ImmutableList.Builder sourceHandles = ImmutableList.builder(); + while (true) { + checkState(splitter.isBlocked().isDone(), "not supported"); + Optional next = splitter.getNext(); + if (next.isEmpty()) { + break; + } + sourceHandles.add(next.get()); + } + + for (ExchangeSourceHandle handle : sourceHandles.build()) { + ExchangeSourceStatistics statistics = sourceExchange.getExchangeSourceStatistics(handle); + if (assignedExchangeDataSize != 0 && assignedExchangeDataSize + statistics.getSizeInBytes() > targetPartitionSizeInBytes) { + result.add(new TaskDescriptor(currentPartitionId++, ImmutableListMultimap.of(), assignedExchangeSourceHandles.build(), nodeRequirements)); + assignedExchangeSourceHandles = ImmutableListMultimap.builder(); + assignedExchangeDataSize = 0; + } + + assignedExchangeSourceHandles.put(remoteSourcePlanNodeId, handle); + assignedExchangeDataSize += statistics.getSizeInBytes(); + } + } + + if (assignedExchangeDataSize != 0) { + result.add(new TaskDescriptor(currentPartitionId, ImmutableListMultimap.of(), assignedExchangeSourceHandles.build(), nodeRequirements)); + } + + finished = true; + return result.build(); + } + + @Override + public boolean isFinished() + { + return finished; + } + + @Override + public void close() + { + } + } + + public static class HashDistributionTaskSource + implements TaskSource + { + private final Map splitSources; + private 
final Multimap partitionedExchangeSourceHandles; + private final Multimap replicatedExchangeSourceHandles; + private final int splitBatchSize; + private final LongConsumer getSplitTimeRecorder; + private final Optional bucketToPartitionMap; + private final Optional bucketNodeMap; + private final Optional catalogRequirement; + + private boolean finished; + private boolean closed; + + public static HashDistributionTaskSource create( + Session session, + PlanFragment fragment, + SplitSourceFactory splitSourceFactory, + Multimap exchangeSourceHandles, + int splitBatchSize, + LongConsumer getSplitTimeRecorder, + Optional bucketToPartitionMap, + Optional bucketNodeMap) + { + checkArgument(bucketNodeMap.isPresent() || fragment.getPartitionedSources().isEmpty(), "bucketNodeMap is expected to be set"); + Map splitSources = splitSourceFactory.createSplitSources(session, fragment); + return new HashDistributionTaskSource( + splitSources, + getPartitionedExchangeSourceHandles(fragment, exchangeSourceHandles), + getReplicatedExchangeSourceHandles(fragment, exchangeSourceHandles), + splitBatchSize, + getSplitTimeRecorder, + bucketToPartitionMap, + bucketNodeMap, + fragment.getPartitioning().getConnectorId()); + } + + public HashDistributionTaskSource( + Map splitSources, + Multimap partitionedExchangeSourceHandles, + Multimap replicatedExchangeSourceHandles, + int splitBatchSize, + LongConsumer getSplitTimeRecorder, + Optional bucketToPartitionMap, + Optional bucketNodeMap, + Optional catalogRequirement) + { + this.splitSources = ImmutableMap.copyOf(requireNonNull(splitSources, "splitSources is null")); + this.partitionedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(partitionedExchangeSourceHandles, "partitionedExchangeSourceHandles is null")); + this.replicatedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(replicatedExchangeSourceHandles, "replicatedExchangeSourceHandles is null")); + this.splitBatchSize = splitBatchSize; + 
this.getSplitTimeRecorder = requireNonNull(getSplitTimeRecorder, "getSplitTimeRecorder is null"); + this.bucketToPartitionMap = requireNonNull(bucketToPartitionMap, "bucketToPartitionMap is null"); + checkArgument(bucketToPartitionMap.isPresent() || partitionedExchangeSourceHandles.isEmpty() || splitSources.isEmpty(), "bucket to partition map is expected to be set when partitioned inputs are present"); + this.bucketNodeMap = requireNonNull(bucketNodeMap, "bucketNodeMap is null"); + checkArgument(splitSources.isEmpty() || bucketNodeMap.isPresent(), "splitToBucketFunction must be set when split sources are present"); + this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null"); + } + + @Override + public List getMoreTasks() + { + if (finished || closed) { + return ImmutableList.of(); + } + + Map> partitionToSplitsMap = new HashMap<>(); + Map partitionToNodeMap = new HashMap<>(); + if (!splitSources.isEmpty()) { + for (Map.Entry entry : splitSources.entrySet()) { + PlanNodeId scanNodeId = entry.getKey(); + SplitSource splitSource = entry.getValue(); + BucketNodeMap bucketNodeMap = this.bucketNodeMap + .orElseThrow(() -> new VerifyException("bucket to node map is expected to be present")); + while (!splitSource.isFinished()) { + ListenableFuture splitBatchFuture = splitSource.getNextBatch(NOT_PARTITIONED, Lifespan.taskWide(), splitBatchSize); + + long start = System.nanoTime(); + addSuccessCallback(splitBatchFuture, () -> getSplitTimeRecorder.accept(start)); + + SplitBatch splitBatch = getFutureValue(splitBatchFuture); + + for (Split split : splitBatch.getSplits()) { + int bucket = bucketNodeMap.getBucket(split); + int partition = bucketToPartitionMap.map(map -> map[bucket]).orElse(bucket); + + if (!bucketNodeMap.isDynamic()) { + partitionToNodeMap.put(partition, bucketNodeMap.getAssignedNode(split).get().getHostAndPort()); + } + + Multimap partitionSplits = partitionToSplitsMap.computeIfAbsent(partition, (p) -> 
ArrayListMultimap.create()); + partitionSplits.put(scanNodeId, split); + } + + if (splitBatch.isLastBatch()) { + splitSource.close(); + break; + } + } + } + } + + Map> partitionToExchangeSourceHandlesMap = new HashMap<>(); + for (Map.Entry entry : partitionedExchangeSourceHandles.entries()) { + PlanNodeId planNodeId = entry.getKey(); + ExchangeSourceHandle handle = entry.getValue(); + int partition = handle.getPartitionId(); + Multimap partitionSourceHandles = partitionToExchangeSourceHandlesMap.computeIfAbsent(partition, (p) -> ArrayListMultimap.create()); + partitionSourceHandles.put(planNodeId, handle); + } + + // TODO implement small tasks merging + + int taskPartitionId = 0; + ImmutableList.Builder result = ImmutableList.builder(); + for (Integer partition : union(partitionToSplitsMap.keySet(), partitionToExchangeSourceHandlesMap.keySet())) { + Multimap splits = partitionToSplitsMap.getOrDefault(partition, ImmutableMultimap.of()); + Multimap exchangeSourceHandles = ImmutableListMultimap.builder() + .putAll(partitionToExchangeSourceHandlesMap.getOrDefault(partition, ImmutableMultimap.of())) + .putAll(replicatedExchangeSourceHandles) + .build(); + Optional> hostRequirement = Optional.ofNullable(partitionToNodeMap.get(partition)).map(ImmutableSet::of); + result.add(new TaskDescriptor(taskPartitionId++, splits, exchangeSourceHandles, new NodeRequirements(catalogRequirement, hostRequirement))); + } + + finished = true; + return result.build(); + } + + @Override + public boolean isFinished() + { + return finished; + } + + @Override + public void close() + { + if (closed) { + return; + } + closed = true; + for (SplitSource splitSource : splitSources.values()) { + try { + splitSource.close(); + } + catch (RuntimeException e) { + log.error(e, "Error closing split source"); + } + } + } + } + + public static class SourceDistributionTaskSource + implements TaskSource + { + private final QueryId queryId; + private final PlanNodeId partitionedSourceNodeId; + private final 
TableExecuteContextManager tableExecuteContextManager; + private final SplitSource splitSource; + private final Multimap replicatedExchangeSourceHandles; + private final int splitBatchSize; + private final LongConsumer getSplitTimeRecorder; + private final Optional catalogRequirement; + private final int targetPartitionSplitCount; + + private final Queue remotelyAccessibleSplitBuffer = new ArrayDeque<>(); + private final Map> locallyAccessibleSplitBuffer = new HashMap<>(); + + private int currentPartitionId; + private boolean finished; + private boolean closed; + + public static SourceDistributionTaskSource create( + Session session, + PlanFragment fragment, + Multimap exchangeSourceHandles, + SplitSourceFactory splitSourceFactory, + TableExecuteContextManager tableExecuteContextManager, + int splitBatchSize, + LongConsumer getSplitTimeRecorder, + int targetPartitionSplitCount) + { + checkArgument(fragment.getPartitionedSources().size() == 1, "single partitioned source is expected, got: %s", fragment.getPartitionedSources()); + + List remoteSourceNodes = fragment.getRemoteSourceNodes(); + checkArgument(remoteSourceNodes.stream().allMatch(node -> node.getExchangeType() == REPLICATE), "only replicated exchanges are expected in source distributed stage, got: %s", remoteSourceNodes); + + PlanNodeId partitionedSourceNodeId = getOnlyElement(fragment.getPartitionedSources()); + Map splitSources = splitSourceFactory.createSplitSources(session, fragment); + SplitSource splitSource = splitSources.get(partitionedSourceNodeId); + + Optional catalogName = Optional.of(splitSource.getCatalogName()) + .filter(catalog -> !isInternalSystemConnector(catalog)); + + return new SourceDistributionTaskSource( + session.getQueryId(), + partitionedSourceNodeId, + tableExecuteContextManager, + splitSource, + getReplicatedExchangeSourceHandles(fragment, exchangeSourceHandles), + splitBatchSize, + getSplitTimeRecorder, + catalogName, + targetPartitionSplitCount); + } + + public 
SourceDistributionTaskSource( + QueryId queryId, + PlanNodeId partitionedSourceNodeId, + TableExecuteContextManager tableExecuteContextManager, + SplitSource splitSource, + Multimap replicatedExchangeSourceHandles, + int splitBatchSize, + LongConsumer getSplitTimeRecorder, + Optional catalogRequirement, + int targetPartitionSplitCount) + { + this.queryId = requireNonNull(queryId, "queryId is null"); + this.partitionedSourceNodeId = requireNonNull(partitionedSourceNodeId, "partitionedSourceNodeId is null"); + this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); + this.splitSource = requireNonNull(splitSource, "splitSource is null"); + this.replicatedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(replicatedExchangeSourceHandles, "replicatedExchangeSourceHandles is null")); + this.splitBatchSize = splitBatchSize; + this.getSplitTimeRecorder = requireNonNull(getSplitTimeRecorder, "getSplitTimeRecorder is null"); + this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null"); + checkArgument(targetPartitionSplitCount > 0, "targetPartitionSplitCount must be positive: %s", targetPartitionSplitCount); + this.targetPartitionSplitCount = targetPartitionSplitCount; + } + + @Override + public List getMoreTasks() + { + if (finished || closed) { + return ImmutableList.of(); + } + + while (true) { + if (remotelyAccessibleSplitBuffer.size() >= targetPartitionSplitCount) { + ImmutableList.Builder splits = ImmutableList.builder(); + for (int i = 0; i < targetPartitionSplitCount; i++) { + splits.add(remotelyAccessibleSplitBuffer.poll()); + } + return ImmutableList.of( + new TaskDescriptor( + currentPartitionId++, + ImmutableListMultimap.builder().putAll(partitionedSourceNodeId, splits.build()).build(), + replicatedExchangeSourceHandles, + new NodeRequirements(catalogRequirement, Optional.empty()))); + } + for (HostAddress remoteHost : 
locallyAccessibleSplitBuffer.keySet()) { + Set hostSplits = locallyAccessibleSplitBuffer.get(remoteHost); + if (hostSplits.size() >= targetPartitionSplitCount) { + List splits = removeN(hostSplits, targetPartitionSplitCount); + locallyAccessibleSplitBuffer.values().forEach(values -> splits.forEach(values::remove)); + return ImmutableList.of( + new TaskDescriptor( + currentPartitionId++, + ImmutableListMultimap.builder().putAll(partitionedSourceNodeId, splits).build(), + replicatedExchangeSourceHandles, + new NodeRequirements(catalogRequirement, Optional.of(ImmutableSet.of(remoteHost))))); + } + } + + if (splitSource.isFinished()) { + break; + } + + ListenableFuture splitBatchFuture = splitSource.getNextBatch(NOT_PARTITIONED, Lifespan.taskWide(), splitBatchSize); + + long start = System.nanoTime(); + addSuccessCallback(splitBatchFuture, () -> getSplitTimeRecorder.accept(start)); + + List splits = getFutureValue(splitBatchFuture).getSplits(); + + for (Split split : splits) { + if (split.isRemotelyAccessible()) { + remotelyAccessibleSplitBuffer.add(split); + } + else { + List addresses = split.getAddresses(); + checkArgument(!addresses.isEmpty(), "split is not remotely accessible but the list of addresses is empty"); + for (HostAddress hostAddress : addresses) { + locallyAccessibleSplitBuffer.computeIfAbsent(hostAddress, key -> newIdentityHashSet()).add(split); + } + } + } + } + + Optional> tableExecuteSplitsInfo = splitSource.getTableExecuteSplitsInfo(); + + // Here we assume that we can get non-empty tableExecuteSplitsInfo only for queries which facilitate single split source. 
+ tableExecuteSplitsInfo.ifPresent(info -> { + TableExecuteContext tableExecuteContext = tableExecuteContextManager.getTableExecuteContextForQuery(queryId); + tableExecuteContext.setSplitsInfo(info); + }); + + finished = true; + splitSource.close(); + + ImmutableList.Builder result = ImmutableList.builder(); + + if (!remotelyAccessibleSplitBuffer.isEmpty()) { + result.add(new TaskDescriptor( + currentPartitionId++, + ImmutableListMultimap.builder().putAll(partitionedSourceNodeId, ImmutableList.copyOf(remotelyAccessibleSplitBuffer)).build(), + replicatedExchangeSourceHandles, + new NodeRequirements(catalogRequirement, Optional.empty()))); + remotelyAccessibleSplitBuffer.clear(); + } + + if (!locallyAccessibleSplitBuffer.isEmpty()) { + for (HostAddress remoteHost : locallyAccessibleSplitBuffer.keySet()) { + List splits = ImmutableList.copyOf(locallyAccessibleSplitBuffer.get(remoteHost)); + if (!splits.isEmpty()) { + locallyAccessibleSplitBuffer.values().forEach(values -> splits.forEach(values::remove)); + result.add(new TaskDescriptor( + currentPartitionId++, + ImmutableListMultimap.builder().putAll(partitionedSourceNodeId, splits).build(), + replicatedExchangeSourceHandles, + new NodeRequirements(catalogRequirement, Optional.of(ImmutableSet.of(remoteHost))))); + } + } + locallyAccessibleSplitBuffer.clear(); + } + + return result.build(); + } + + private static List removeN(Collection collection, int n) + { + ImmutableList.Builder result = ImmutableList.builder(); + Iterator iterator = collection.iterator(); + for (int i = 0; i < n && iterator.hasNext(); i++) { + T item = iterator.next(); + iterator.remove(); + result.add(item); + } + return result.build(); + } + + @Override + public boolean isFinished() + { + return finished; + } + + @Override + public void close() + { + if (closed) { + return; + } + closed = true; + splitSource.close(); + } + } + + private static Multimap getReplicatedExchangeSourceHandles(PlanFragment fragment, Multimap handles) + { + return 
getInputsForRemoteSources( + fragment.getRemoteSourceNodes().stream() + .filter(remoteSource -> remoteSource.getExchangeType() == REPLICATE) + .collect(toImmutableList()), + handles); + } + + private static Multimap getPartitionedExchangeSourceHandles(PlanFragment fragment, Multimap handles) + { + return getInputsForRemoteSources( + fragment.getRemoteSourceNodes().stream() + .filter(remoteSource -> remoteSource.getExchangeType() != REPLICATE) + .collect(toImmutableList()), + handles); + } + + private static Map getSourceFragmentToRemoteSourceNodeIdMap(List remoteSourceNodes) + { + ImmutableMap.Builder result = ImmutableMap.builder(); + for (RemoteSourceNode node : remoteSourceNodes) { + for (PlanFragmentId sourceFragmentId : node.getSourceFragmentIds()) { + result.put(sourceFragmentId, node.getId()); + } + } + return result.build(); + } + + private static Multimap getInputsForRemoteSources( + List remoteSources, + Multimap exchangeSourceHandles) + { + ImmutableListMultimap.Builder result = ImmutableListMultimap.builder(); + for (RemoteSourceNode remoteSource : remoteSources) { + for (PlanFragmentId fragmentId : remoteSource.getSourceFragmentIds()) { + Collection handles = requireNonNull(exchangeSourceHandles.get(fragmentId), () -> "exchange source handle is missing for fragment: " + fragmentId); + if (remoteSource.getExchangeType() == GATHER || remoteSource.getExchangeType() == REPLICATE) { + checkArgument(handles.size() <= 1, "at most 1 exchange source handle is expected, got: %s", handles); + } + result.putAll(remoteSource.getId(), handles); + } + } + return result.build(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskDescriptor.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskDescriptor.java new file mode 100644 index 000000000000..6d80fb56b161 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskDescriptor.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.Multimap; +import io.trino.metadata.Split; +import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.sql.planner.plan.PlanNodeId; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class TaskDescriptor +{ + private final int partitionId; + private final Multimap splits; + private final Multimap exchangeSourceHandles; + private final NodeRequirements nodeRequirements; + + public TaskDescriptor( + int partitionId, + Multimap splits, + Multimap exchangeSourceHandles, + NodeRequirements nodeRequirements) + { + this.partitionId = partitionId; + this.splits = ImmutableMultimap.copyOf(requireNonNull(splits, "splits is null")); + this.exchangeSourceHandles = ImmutableMultimap.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null")); + this.nodeRequirements = requireNonNull(nodeRequirements, "nodeRequirements is null"); + } + + public int getPartitionId() + { + return partitionId; + } + + public Multimap getSplits() + { + return splits; + } + + public Multimap getExchangeSourceHandles() + { + return exchangeSourceHandles; + } + + public NodeRequirements getNodeRequirements() + { + return nodeRequirements; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + 
} + if (o == null || getClass() != o.getClass()) { + return false; + } + TaskDescriptor that = (TaskDescriptor) o; + return partitionId == that.partitionId && Objects.equals(splits, that.splits) && Objects.equals(exchangeSourceHandles, that.exchangeSourceHandles) && Objects.equals(nodeRequirements, that.nodeRequirements); + } + + @Override + public int hashCode() + { + return Objects.hash(partitionId, splits, exchangeSourceHandles, nodeRequirements); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("partitionId", partitionId) + .add("splits", splits) + .add("exchangeSourceHandles", exchangeSourceHandles) + .add("nodeRequirements", nodeRequirements) + .toString(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSource.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSource.java new file mode 100644 index 000000000000..d7891c9b9b0f --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSource.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import java.io.Closeable; +import java.util.List; + +public interface TaskSource + extends Closeable +{ + List getMoreTasks(); + + boolean isFinished(); + + @Override + void close(); +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSourceFactory.java new file mode 100644 index 000000000000..70f610a8aeeb --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSourceFactory.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import com.google.common.collect.Multimap; +import io.trino.Session; +import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.plan.PlanFragmentId; + +import java.util.Map; +import java.util.Optional; +import java.util.function.LongConsumer; + +public interface TaskSourceFactory +{ + TaskSource create( + Session session, + PlanFragment fragment, + Map sourceExchanges, + Multimap exchangeSourceHandles, + LongConsumer getSplitTimeRecorder, + Optional bucketToPartitionMap, + Optional bucketNodeMap); +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/group/DynamicBucketNodeMap.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/group/DynamicBucketNodeMap.java index 543f21018be0..0bf1885c8ca9 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/group/DynamicBucketNodeMap.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/group/DynamicBucketNodeMap.java @@ -51,6 +51,12 @@ public int getBucketCount() return bucketCount; } + @Override + public int getNodeCount() + { + throw new UnsupportedOperationException(); + } + @Override public void assignBucketToNode(int bucketedId, InternalNode node) { diff --git a/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java b/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java index 8e8c3b405b76..042a6cf59414 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java +++ b/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java @@ -27,6 +27,8 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayoutHandle; import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; +import io.trino.spi.exchange.ExchangeSourceHandle; public class 
HandleJsonModule implements Module @@ -96,4 +98,16 @@ public static com.fasterxml.jackson.databind.Module partitioningHandleModule(Han { return new AbstractTypedJacksonModule<>(ConnectorPartitioningHandle.class, resolver::getId, resolver::getPartitioningHandleClass) {}; } + + @ProvidesIntoSet + public static com.fasterxml.jackson.databind.Module exchangeSinkInstanceHandleModule(HandleResolver resolver) + { + return new AbstractTypedJacksonModule<>(ExchangeSinkInstanceHandle.class, (clazz) -> clazz.getClass().getSimpleName(), (ignored) -> resolver.getExchangeSinkInstanceHandleClass()) {}; + } + + @ProvidesIntoSet + public static com.fasterxml.jackson.databind.Module exchangeSourceHandleModule(HandleResolver resolver) + { + return new AbstractTypedJacksonModule<>(ExchangeSourceHandle.class, (clazz) -> clazz.getClass().getSimpleName(), (ignored) -> resolver.getExchangeSourceHandleHandleClass()) {}; + } } diff --git a/core/trino-main/src/main/java/io/trino/metadata/HandleResolver.java b/core/trino-main/src/main/java/io/trino/metadata/HandleResolver.java index bd8e834f2a27..c8018fed67d6 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/HandleResolver.java +++ b/core/trino-main/src/main/java/io/trino/metadata/HandleResolver.java @@ -26,6 +26,9 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayoutHandle; import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.exchange.ExchangeManagerHandleResolver; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; +import io.trino.spi.exchange.ExchangeSourceHandle; import io.trino.split.EmptySplitHandleResolver; import javax.inject.Inject; @@ -35,6 +38,7 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.Supplier; @@ -45,7 +49,8 @@ public final class HandleResolver 
{ - private final ConcurrentMap handleResolvers = new ConcurrentHashMap<>(); + private final ConcurrentMap catalogHandleResolvers = new ConcurrentHashMap<>(); + private final AtomicReference exchangeManagerHandleResolver = new AtomicReference<>(); @Inject public HandleResolver() @@ -60,13 +65,18 @@ public void addCatalogHandleResolver(String catalogName, ConnectorHandleResolver { requireNonNull(catalogName, "catalogName is null"); requireNonNull(resolver, "resolver is null"); - MaterializedHandleResolver existingResolver = handleResolvers.putIfAbsent(catalogName, new MaterializedHandleResolver(resolver)); + MaterializedHandleResolver existingResolver = catalogHandleResolvers.putIfAbsent(catalogName, new MaterializedHandleResolver(resolver)); checkState(existingResolver == null, "Catalog '%s' is already assigned to resolver: %s", catalogName, existingResolver); } + public void setExchangeManagerHandleResolver(ExchangeManagerHandleResolver resolver) + { + checkState(exchangeManagerHandleResolver.compareAndSet(null, resolver), "Exchange manager handle resolver is already set"); + } + public void removeCatalogHandleResolver(String catalogName) { - handleResolvers.remove(catalogName); + catalogHandleResolvers.remove(catalogName); } public String getId(ConnectorTableHandle tableHandle) @@ -169,16 +179,30 @@ public Class getTransactionHandleClass(Str return resolverFor(id).getTransactionHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } + public Class getExchangeSinkInstanceHandleClass() + { + ExchangeManagerHandleResolver resolver = exchangeManagerHandleResolver.get(); + checkState(resolver != null, "Exchange manager handle resolver is not set"); + return resolver.getExchangeSinkInstanceHandleClass(); + } + + public Class getExchangeSourceHandleHandleClass() + { + ExchangeManagerHandleResolver resolver = exchangeManagerHandleResolver.get(); + checkState(resolver != null, "Exchange manager handle resolver is not set"); + return 
resolver.getExchangeSourceHandleHandleClass(); + } + private MaterializedHandleResolver resolverFor(String id) { - MaterializedHandleResolver resolver = handleResolvers.get(id); + MaterializedHandleResolver resolver = catalogHandleResolvers.get(id); checkArgument(resolver != null, "No handle resolver for connector: %s", id); return resolver; } private String getId(T handle, Function>> getter) { - for (Entry entry : handleResolvers.entrySet()) { + for (Entry entry : catalogHandleResolvers.entrySet()) { try { if (getter.apply(entry.getValue()).map(clazz -> clazz.isInstance(handle)).orElse(false)) { return entry.getKey(); diff --git a/core/trino-main/src/main/java/io/trino/operator/DeduplicationExchangeClientBuffer.java b/core/trino-main/src/main/java/io/trino/operator/DeduplicatingDirectExchangeBuffer.java similarity index 76% rename from core/trino-main/src/main/java/io/trino/operator/DeduplicationExchangeClientBuffer.java rename to core/trino-main/src/main/java/io/trino/operator/DeduplicatingDirectExchangeBuffer.java index 54c8c13522f2..3947d04ad41c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DeduplicationExchangeClientBuffer.java +++ b/core/trino-main/src/main/java/io/trino/operator/DeduplicatingDirectExchangeBuffer.java @@ -31,6 +31,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; +import java.util.function.Predicate; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -38,14 +39,14 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; +import static io.trino.operator.RetryPolicy.NONE; import static io.trino.operator.RetryPolicy.QUERY; -import static io.trino.operator.RetryPolicy.TASK; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static 
java.lang.Math.max; import static java.util.Objects.requireNonNull; -public class DeduplicationExchangeClientBuffer - implements ExchangeClientBuffer +public class DeduplicatingDirectExchangeBuffer + implements DirectExchangeBuffer { private final Executor executor; private final long bufferCapacityInBytes; @@ -79,12 +80,12 @@ public class DeduplicationExchangeClientBuffer @GuardedBy("this") private boolean closed; - public DeduplicationExchangeClientBuffer(Executor executor, DataSize bufferCapacity, RetryPolicy retryPolicy) + public DeduplicatingDirectExchangeBuffer(Executor executor, DataSize bufferCapacity, RetryPolicy retryPolicy) { this.executor = requireNonNull(executor, "executor is null"); this.bufferCapacityInBytes = requireNonNull(bufferCapacity, "bufferCapacity is null").toBytes(); requireNonNull(retryPolicy, "retryPolicy is null"); - checkArgument(retryPolicy == QUERY, "retryPolicy is expected to be QUERY: %s", retryPolicy); + checkArgument(retryPolicy != NONE, "retryPolicy is not expected to be NONE"); this.retryPolicy = retryPolicy; } @@ -188,10 +189,6 @@ public synchronized void taskFinished(TaskId taskId) checkState(!failedTasks.containsKey(taskId), "task is failed: %s", taskId); checkState(successfulTasks.add(taskId), "task is finished: %s", taskId); - if (retryPolicy == TASK) { - // TODO implement deduplication for task level retries - throw new UnsupportedOperationException("task level retry policy is unsupported"); - } checkInputFinished(); } @@ -239,10 +236,46 @@ private synchronized void checkInputFinished() return; } + List failures; switch (retryPolicy) { - case TASK: - // TODO implement deduplication for task level retries - throw new UnsupportedOperationException("task level retry policy is unsupported"); + case TASK: { + Set allPartitions = allTasks.stream() + .map(TaskId::getPartitionId) + .collect(toImmutableSet()); + + Set successfulPartitions = successfulTasks.stream() + .map(TaskId::getPartitionId) + .collect(toImmutableSet()); + + 
if (successfulPartitions.containsAll(allPartitions)) { + Map partitionToTaskMap = new HashMap<>(); + for (TaskId successfulTaskId : successfulTasks) { + Integer partitionId = successfulTaskId.getPartitionId(); + TaskId existing = partitionToTaskMap.get(partitionId); + if (existing == null || existing.getAttemptId() > successfulTaskId.getAttemptId()) { + partitionToTaskMap.put(partitionId, successfulTaskId); + } + } + + removePagesFor(taskId -> !taskId.equals(partitionToTaskMap.get(taskId.getPartitionId()))); + inputFinished = true; + unblock(blocked); + return; + } + + Set runningPartitions = allTasks.stream() + .filter(taskId -> !successfulTasks.contains(taskId)) + .filter(taskId -> !failedTasks.containsKey(taskId)) + .map(TaskId::getPartitionId) + .collect(toImmutableSet()); + + failures = failedTasks.entrySet().stream() + .filter(entry -> !successfulPartitions.contains(entry.getKey().getPartitionId())) + .filter(entry -> !runningPartitions.contains(entry.getKey().getPartitionId())) + .map(Map.Entry::getValue) + .collect(toImmutableList()); + break; + } case QUERY: { Set latestAttemptTasks = allTasks.stream() .filter(taskId -> taskId.getAttemptId() == maxAttemptId) @@ -255,41 +288,46 @@ private synchronized void checkInputFinished() return; } - List failures = failedTasks.entrySet().stream() + failures = failedTasks.entrySet().stream() .filter(entry -> entry.getKey().getAttemptId() == maxAttemptId) .map(Map.Entry::getValue) .collect(toImmutableList()); - - if (!failures.isEmpty()) { - Throwable failure = null; - for (Throwable taskFailure : failures) { - if (failure == null) { - failure = taskFailure; - } - else if (failure != taskFailure) { - failure.addSuppressed(taskFailure); - } - } - pageBuffer.clear(); - bufferRetainedSizeInBytes = 0; - this.failure = failure; - unblock(blocked); - } break; } default: throw new UnsupportedOperationException("unexpected retry policy: " + retryPolicy); } + + if (!failures.isEmpty()) { + Throwable failure = null; + for 
(Throwable taskFailure : failures) { + if (failure == null) { + failure = taskFailure; + } + else if (failure != taskFailure) { + failure.addSuppressed(taskFailure); + } + } + pageBuffer.clear(); + bufferRetainedSizeInBytes = 0; + this.failure = failure; + unblock(blocked); + } } private synchronized void removePagesForPreviousAttempts(int currentAttemptId) { - // wipe previous attempt pages + removePagesFor(task -> task.getAttemptId() < currentAttemptId); + } + + private synchronized void removePagesFor(Predicate taskIdPredicate) + { long pagesRetainedSizeInBytes = 0; Iterator> iterator = pageBuffer.entries().iterator(); while (iterator.hasNext()) { Map.Entry entry = iterator.next(); - if (entry.getKey().getAttemptId() < currentAttemptId) { + TaskId taskId = entry.getKey(); + if (taskIdPredicate.test(taskId)) { pagesRetainedSizeInBytes += entry.getValue().getRetainedSizeInBytes(); iterator.remove(); } diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientBuffer.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeBuffer.java similarity index 97% rename from core/trino-main/src/main/java/io/trino/operator/ExchangeClientBuffer.java rename to core/trino-main/src/main/java/io/trino/operator/DirectExchangeBuffer.java index 7990a63a510b..7e4eae47a317 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientBuffer.java +++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeBuffer.java @@ -20,7 +20,7 @@ import java.io.Closeable; import java.util.List; -public interface ExchangeClientBuffer +public interface DirectExchangeBuffer extends Closeable { /** diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeClient.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java similarity index 95% rename from core/trino-main/src/main/java/io/trino/operator/ExchangeClient.java rename to core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java index 
15d793b4200a..d5df9b4617df 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java @@ -47,7 +47,7 @@ import static java.util.Objects.requireNonNull; @ThreadSafe -public class ExchangeClient +public class DirectExchangeClient implements Closeable { private final String selfAddress; @@ -68,7 +68,7 @@ public class ExchangeClient private final Deque queuedClients = new LinkedList<>(); private final Set completedClients = newConcurrentHashSet(); - private final ExchangeClientBuffer buffer; + private final DirectExchangeBuffer buffer; @GuardedBy("this") private long successfulRequests; @@ -81,12 +81,12 @@ public class ExchangeClient private final Executor pageBufferClientCallbackExecutor; private final TaskFailureListener taskFailureListener; - // ExchangeClientStatus.mergeWith assumes all clients have the same bufferCapacity. + // DirectExchangeClientStatus.mergeWith assumes all clients have the same bufferCapacity. // Please change that method accordingly when this assumption becomes not true. - public ExchangeClient( + public DirectExchangeClient( String selfAddress, DataIntegrityVerification dataIntegrityVerification, - ExchangeClientBuffer buffer, + DirectExchangeBuffer buffer, DataSize maxResponseSize, int concurrentRequestMultiplier, Duration maxErrorDuration, @@ -111,7 +111,7 @@ public ExchangeClient( this.taskFailureListener = requireNonNull(taskFailureListener, "taskFailureListener is null"); } - public ExchangeClientStatus getStatus() + public DirectExchangeClientStatus getStatus() { // The stats created by this method is only for diagnostics. // It does not guarantee a consistent view between different exchange clients. 
@@ -122,7 +122,7 @@ public ExchangeClientStatus getStatus() } List pageBufferClientStatus = pageBufferClientStatusBuilder.build(); synchronized (this) { - return new ExchangeClientStatus( + return new DirectExchangeClientStatus( buffer.getRetainedSizeInBytes(), buffer.getMaxRetainedSizeInBytes(), averageBytesPerRequest, @@ -338,20 +338,20 @@ public boolean addPages(HttpPageBufferClient client, List pages) { requireNonNull(client, "client is null"); requireNonNull(pages, "pages is null"); - return ExchangeClient.this.addPages(client, pages); + return DirectExchangeClient.this.addPages(client, pages); } @Override public void requestComplete(HttpPageBufferClient client) { requireNonNull(client, "client is null"); - ExchangeClient.this.requestComplete(client); + DirectExchangeClient.this.requestComplete(client); } @Override public void clientFinished(HttpPageBufferClient client) { - ExchangeClient.this.clientFinished(client); + DirectExchangeClient.this.clientFinished(client); } @Override @@ -359,7 +359,7 @@ public void clientFailed(HttpPageBufferClient client, Throwable cause) { requireNonNull(client, "client is null"); requireNonNull(cause, "cause is null"); - ExchangeClient.this.clientFailed(client, cause); + DirectExchangeClient.this.clientFailed(client, cause); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientConfig.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientConfig.java similarity index 81% rename from core/trino-main/src/main/java/io/trino/operator/ExchangeClientConfig.java rename to core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientConfig.java index 589c35687ff3..4f0165ab5811 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientConfig.java +++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientConfig.java @@ -26,7 +26,7 @@ import java.util.concurrent.TimeUnit; -public class ExchangeClientConfig +public class DirectExchangeClientConfig { 
private DataSize maxBufferSize = DataSize.of(32, Unit.MEGABYTE); private int concurrentRequestMultiplier = 3; @@ -43,7 +43,7 @@ public DataSize getMaxBufferSize() } @Config("exchange.max-buffer-size") - public ExchangeClientConfig setMaxBufferSize(DataSize maxBufferSize) + public DirectExchangeClientConfig setMaxBufferSize(DataSize maxBufferSize) { this.maxBufferSize = maxBufferSize; return this; @@ -56,7 +56,7 @@ public int getConcurrentRequestMultiplier() } @Config("exchange.concurrent-request-multiplier") - public ExchangeClientConfig setConcurrentRequestMultiplier(int concurrentRequestMultiplier) + public DirectExchangeClientConfig setConcurrentRequestMultiplier(int concurrentRequestMultiplier) { this.concurrentRequestMultiplier = concurrentRequestMultiplier; return this; @@ -70,7 +70,7 @@ public Duration getMinErrorDuration() @Deprecated @Config("exchange.min-error-duration") - public ExchangeClientConfig setMinErrorDuration(Duration minErrorDuration) + public DirectExchangeClientConfig setMinErrorDuration(Duration minErrorDuration) { return this; } @@ -83,7 +83,7 @@ public Duration getMaxErrorDuration() } @Config("exchange.max-error-duration") - public ExchangeClientConfig setMaxErrorDuration(Duration maxErrorDuration) + public DirectExchangeClientConfig setMaxErrorDuration(Duration maxErrorDuration) { this.maxErrorDuration = maxErrorDuration; return this; @@ -97,7 +97,7 @@ public DataSize getMaxResponseSize() } @Config("exchange.max-response-size") - public ExchangeClientConfig setMaxResponseSize(DataSize maxResponseSize) + public DirectExchangeClientConfig setMaxResponseSize(DataSize maxResponseSize) { this.maxResponseSize = maxResponseSize; return this; @@ -110,7 +110,7 @@ public int getClientThreads() } @Config("exchange.client-threads") - public ExchangeClientConfig setClientThreads(int clientThreads) + public DirectExchangeClientConfig setClientThreads(int clientThreads) { this.clientThreads = clientThreads; return this; @@ -123,7 +123,7 @@ public int 
getPageBufferClientMaxCallbackThreads() } @Config("exchange.page-buffer-client.max-callback-threads") - public ExchangeClientConfig setPageBufferClientMaxCallbackThreads(int pageBufferClientMaxCallbackThreads) + public DirectExchangeClientConfig setPageBufferClientMaxCallbackThreads(int pageBufferClientMaxCallbackThreads) { this.pageBufferClientMaxCallbackThreads = pageBufferClientMaxCallbackThreads; return this; @@ -135,7 +135,7 @@ public boolean isAcknowledgePages() } @Config("exchange.acknowledge-pages") - public ExchangeClientConfig setAcknowledgePages(boolean acknowledgePages) + public DirectExchangeClientConfig setAcknowledgePages(boolean acknowledgePages) { this.acknowledgePages = acknowledgePages; return this; diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientFactory.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientFactory.java similarity index 91% rename from core/trino-main/src/main/java/io/trino/operator/ExchangeClientFactory.java rename to core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientFactory.java index 58eafe70cef9..806e4072c976 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientFactory.java @@ -37,8 +37,8 @@ import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newFixedThreadPool; -public class ExchangeClientFactory - implements ExchangeClientSupplier +public class DirectExchangeClientFactory + implements DirectExchangeClientSupplier { private final NodeInfo nodeInfo; private final DataIntegrityVerification dataIntegrityVerification; @@ -53,10 +53,10 @@ public class ExchangeClientFactory private final ExecutorService pageBufferClientCallbackExecutor; @Inject - public ExchangeClientFactory( + public DirectExchangeClientFactory( NodeInfo nodeInfo, FeaturesConfig featuresConfig, - ExchangeClientConfig config, + 
DirectExchangeClientConfig config, @ForExchange HttpClient httpClient, @ForExchange ScheduledExecutorService scheduler) { @@ -73,7 +73,7 @@ public ExchangeClientFactory( scheduler); } - public ExchangeClientFactory( + public DirectExchangeClientFactory( NodeInfo nodeInfo, DataIntegrityVerification dataIntegrityVerification, DataSize maxBufferedBytes, @@ -123,22 +123,22 @@ public ThreadPoolExecutorMBean getExecutor() } @Override - public ExchangeClient get(LocalMemoryContext systemMemoryContext, TaskFailureListener taskFailureListener, RetryPolicy retryPolicy) + public DirectExchangeClient get(LocalMemoryContext systemMemoryContext, TaskFailureListener taskFailureListener, RetryPolicy retryPolicy) { - ExchangeClientBuffer buffer; + DirectExchangeBuffer buffer; switch (retryPolicy) { case TASK: case QUERY: - buffer = new DeduplicationExchangeClientBuffer(scheduler, maxBufferedBytes, retryPolicy); + buffer = new DeduplicatingDirectExchangeBuffer(scheduler, maxBufferedBytes, retryPolicy); break; case NONE: - buffer = new StreamingExchangeClientBuffer(scheduler, maxBufferedBytes); + buffer = new StreamingDirectExchangeBuffer(scheduler, maxBufferedBytes); break; default: throw new IllegalArgumentException("unexpected retry policy: " + retryPolicy); } - return new ExchangeClient( + return new DirectExchangeClient( nodeInfo.getExternalAddress(), dataIntegrityVerification, buffer, diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientStatus.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientStatus.java similarity index 94% rename from core/trino-main/src/main/java/io/trino/operator/ExchangeClientStatus.java rename to core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientStatus.java index b0266c0c0f1d..97fe42b67732 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientStatus.java +++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientStatus.java @@ -23,8 +23,8 @@ import static 
com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; -public class ExchangeClientStatus - implements Mergeable, OperatorInfo +public class DirectExchangeClientStatus + implements Mergeable, OperatorInfo { private final long bufferedBytes; private final long maxBufferedBytes; @@ -35,7 +35,7 @@ public class ExchangeClientStatus private final List pageBufferClientStatuses; @JsonCreator - public ExchangeClientStatus( + public DirectExchangeClientStatus( @JsonProperty("bufferedBytes") long bufferedBytes, @JsonProperty("maxBufferedBytes") long maxBufferedBytes, @JsonProperty("averageBytesPerRequest") long averageBytesPerRequest, @@ -116,9 +116,9 @@ public String toString() } @Override - public ExchangeClientStatus mergeWith(ExchangeClientStatus other) + public DirectExchangeClientStatus mergeWith(DirectExchangeClientStatus other) { - return new ExchangeClientStatus( + return new DirectExchangeClientStatus( (bufferedBytes + other.bufferedBytes) / 2, // this is correct as long as all clients have the same buffer size (capacity) Math.max(maxBufferedBytes, other.maxBufferedBytes), mergeAvgs(averageBytesPerRequest, successfulRequestsCount, other.averageBytesPerRequest, other.successfulRequestsCount), diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientSupplier.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientSupplier.java similarity index 79% rename from core/trino-main/src/main/java/io/trino/operator/ExchangeClientSupplier.java rename to core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientSupplier.java index 74176f88a267..a74f5a3e28db 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeClientSupplier.java +++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientSupplier.java @@ -16,7 +16,7 @@ import io.trino.execution.TaskFailureListener; import io.trino.memory.context.LocalMemoryContext; -public interface ExchangeClientSupplier 
+public interface DirectExchangeClientSupplier { - ExchangeClient get(LocalMemoryContext systemMemoryContext, TaskFailureListener taskFailureListener, RetryPolicy retryPolicy); + DirectExchangeClient get(LocalMemoryContext systemMemoryContext, TaskFailureListener taskFailureListener, RetryPolicy retryPolicy); } diff --git a/core/trino-main/src/main/java/io/trino/operator/Driver.java b/core/trino-main/src/main/java/io/trino/operator/Driver.java index 44c6fee62960..5fbd2bba6009 100644 --- a/core/trino-main/src/main/java/io/trino/operator/Driver.java +++ b/core/trino-main/src/main/java/io/trino/operator/Driver.java @@ -24,7 +24,7 @@ import io.airlift.log.Logger; import io.airlift.units.Duration; import io.trino.execution.ScheduledSplit; -import io.trino.execution.TaskSource; +import io.trino.execution.SplitAssignment; import io.trino.metadata.Split; import io.trino.spi.Page; import io.trino.spi.TrinoException; @@ -75,11 +75,11 @@ public class Driver private final Optional deleteOperator; private final Optional updateOperator; - // This variable acts as a staging area. When new splits (encapsulated in TaskSource) are + // This variable acts as a staging area. When new splits (encapsulated in SplitAssignment) are // provided to a Driver, the Driver will not process them right away. Instead, the splits are // added to this staging area. This staging area will be drained asynchronously. That's when // the new splits get processed. 
- private final AtomicReference pendingTaskSourceUpdates = new AtomicReference<>(); + private final AtomicReference pendingSplitAssignmentUpdates = new AtomicReference<>(); private final Map> revokingOperators = new HashMap<>(); private final AtomicReference state = new AtomicReference<>(State.ALIVE); @@ -87,7 +87,7 @@ public class Driver private final DriverLock exclusiveLock = new DriverLock(); @GuardedBy("exclusiveLock") - private TaskSource currentTaskSource; + private SplitAssignment currentSplitAssignment; private final AtomicReference> driverBlockedFuture = new AtomicReference<>(); @@ -147,7 +147,7 @@ else if (operator instanceof UpdateOperator) { this.deleteOperator = deleteOperator; this.updateOperator = updateOperator; - currentTaskSource = sourceOperator.map(operator -> new TaskSource(operator.getSourceId(), ImmutableSet.of(), false)).orElse(null); + currentSplitAssignment = sourceOperator.map(operator -> new SplitAssignment(operator.getSourceId(), ImmutableSet.of(), false)).orElse(null); // initially the driverBlockedFuture is not blocked (it is completed) SettableFuture future = SettableFuture.create(); future.set(null); @@ -209,15 +209,15 @@ private boolean isFinishedInternal() return finished; } - public void updateSource(TaskSource sourceUpdate) + public void updateSplitAssignment(SplitAssignment splitAssignment) { - checkLockNotHeld("Cannot update sources while holding the driver lock"); + checkLockNotHeld("Cannot update assignments while holding the driver lock"); checkArgument( - sourceOperator.isPresent() && sourceOperator.get().getSourceId().equals(sourceUpdate.getPlanNodeId()), - "sourceUpdate is for a plan node that is different from this Driver's source node"); + sourceOperator.isPresent() && sourceOperator.get().getSourceId().equals(splitAssignment.getPlanNodeId()), + "splitAssignment is for a plan node that is different from this Driver's source node"); // stage the new updates - pendingTaskSourceUpdates.updateAndGet(current -> current == 
null ? sourceUpdate : current.update(sourceUpdate)); + pendingSplitAssignmentUpdates.updateAndGet(current -> current == null ? splitAssignment : current.update(splitAssignment)); // attempt to get the lock and process the updates we staged above // updates will be processed in close if and only if we got the lock @@ -234,21 +234,21 @@ private void processNewSources() return; } - TaskSource sourceUpdate = pendingTaskSourceUpdates.getAndSet(null); - if (sourceUpdate == null) { + SplitAssignment splitAssignment = pendingSplitAssignmentUpdates.getAndSet(null); + if (splitAssignment == null) { return; } - // merge the current source and the specified source update - TaskSource newSource = currentTaskSource.update(sourceUpdate); + // merge the current assignment and the specified assignment + SplitAssignment newAssignment = currentSplitAssignment.update(splitAssignment); // if the update contains no new data, just return - if (newSource == currentTaskSource) { + if (newAssignment == currentSplitAssignment) { return; } // determine new splits to add - Set newSplits = Sets.difference(newSource.getSplits(), currentTaskSource.getSplits()); + Set newSplits = Sets.difference(newAssignment.getSplits(), currentSplitAssignment.getSplits()); // add new splits SourceOperator sourceOperator = this.sourceOperator.orElseThrow(VerifyException::new); @@ -261,11 +261,11 @@ private void processNewSources() } // set no more splits - if (newSource.isNoMoreSplits()) { + if (newAssignment.isNoMoreSplits()) { sourceOperator.noMoreSplits(); } - currentTaskSource = newSource; + currentSplitAssignment = newAssignment; } public ListenableFuture processFor(Duration duration) @@ -698,12 +698,12 @@ private Optional tryWithLock(long timeout, TimeUnit unit, Supplier tas } } - // If there are more source updates available, attempt to reacquire the lock and process them. - // This can happen if new sources are added while we're holding the lock here doing work. 
+ // If there are more assignment updates available, attempt to reacquire the lock and process them. + // This can happen if new assignments are added while we're holding the lock here doing work. // NOTE: this is separate duplicate code to make debugging lock reacquisition easier // The first condition is for processing the pending updates if this driver is still ALIVE // The second condition is to destroy the driver if the state is NEED_DESTRUCTION - while (((pendingTaskSourceUpdates.get() != null && state.get() == State.ALIVE) || state.get() == State.NEED_DESTRUCTION) + while (((pendingSplitAssignmentUpdates.get() != null && state.get() == State.ALIVE) || state.get() == State.NEED_DESTRUCTION) && exclusiveLock.tryLock()) { try { try { diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java b/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java index 0893c08d2648..9a3cf88e7211 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java @@ -14,21 +14,37 @@ package io.trino.operator; import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.airlift.slice.Slice; import io.trino.connector.CatalogName; +import io.trino.exchange.ExchangeManagerRegistry; +import io.trino.execution.TaskFailureListener; import io.trino.execution.buffer.PagesSerde; import io.trino.execution.buffer.PagesSerdeFactory; import io.trino.execution.buffer.SerializedPage; +import io.trino.memory.context.LocalMemoryContext; import io.trino.metadata.Split; import io.trino.spi.Page; import io.trino.spi.connector.UpdatablePageSource; +import io.trino.spi.exchange.ExchangeManager; +import io.trino.spi.exchange.ExchangeSource; import io.trino.split.RemoteSplit; +import io.trino.split.RemoteSplit.DirectExchangeInput; +import io.trino.split.RemoteSplit.ExternalExchangeInput; import 
io.trino.sql.planner.plan.PlanNodeId; +import java.io.Closeable; import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.Futures.immediateVoidFuture; +import static io.airlift.concurrent.MoreFutures.asVoid; +import static io.airlift.concurrent.MoreFutures.toListenableFuture; +import static io.trino.execution.buffer.PagesSerdeUtil.readSerializedPage; import static java.util.Objects.requireNonNull; public class ExchangeOperator @@ -41,24 +57,27 @@ public static class ExchangeOperatorFactory { private final int operatorId; private final PlanNodeId sourceId; - private final ExchangeClientSupplier exchangeClientSupplier; + private final DirectExchangeClientSupplier directExchangeClientSupplier; private final PagesSerdeFactory serdeFactory; private final RetryPolicy retryPolicy; - private ExchangeClient exchangeClient; + private final ExchangeManagerRegistry exchangeManagerRegistry; + private ExchangeDataSource exchangeDataSource; private boolean closed; public ExchangeOperatorFactory( int operatorId, PlanNodeId sourceId, - ExchangeClientSupplier exchangeClientSupplier, + DirectExchangeClientSupplier directExchangeClientSupplier, PagesSerdeFactory serdeFactory, - RetryPolicy retryPolicy) + RetryPolicy retryPolicy, + ExchangeManagerRegistry exchangeManagerRegistry) { this.operatorId = operatorId; this.sourceId = sourceId; - this.exchangeClientSupplier = exchangeClientSupplier; + this.directExchangeClientSupplier = directExchangeClientSupplier; this.serdeFactory = serdeFactory; this.retryPolicy = requireNonNull(retryPolicy, "retryPolicy is null"); + this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); } @Override @@ -73,15 +92,15 @@ 
public SourceOperator createOperator(DriverContext driverContext) checkState(!closed, "Factory is already closed"); TaskContext taskContext = driverContext.getPipelineContext().getTaskContext(); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, sourceId, ExchangeOperator.class.getSimpleName()); - if (exchangeClient == null) { - exchangeClient = exchangeClientSupplier.get(driverContext.getPipelineContext().localSystemMemoryContext(), taskContext::sourceTaskFailed, retryPolicy); + LocalMemoryContext memoryContext = driverContext.getPipelineContext().localSystemMemoryContext(); + if (exchangeDataSource == null) { + exchangeDataSource = new LazyExchangeDataSource(directExchangeClientSupplier, memoryContext, taskContext::sourceTaskFailed, retryPolicy, exchangeManagerRegistry); } - return new ExchangeOperator( operatorContext, sourceId, - serdeFactory.createPagesSerde(), - exchangeClient); + exchangeDataSource, + serdeFactory.createPagesSerde()); } @Override @@ -93,22 +112,22 @@ public void noMoreOperators() private final OperatorContext operatorContext; private final PlanNodeId sourceId; - private final ExchangeClient exchangeClient; + private final ExchangeDataSource exchangeDataSource; private final PagesSerde serde; private ListenableFuture isBlocked = NOT_BLOCKED; public ExchangeOperator( OperatorContext operatorContext, PlanNodeId sourceId, - PagesSerde serde, - ExchangeClient exchangeClient) + ExchangeDataSource exchangeDataSource, + PagesSerde serde) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.sourceId = requireNonNull(sourceId, "sourceId is null"); - this.exchangeClient = requireNonNull(exchangeClient, "exchangeClient is null"); + this.exchangeDataSource = requireNonNull(exchangeDataSource, "exchangeDataSource is null"); this.serde = requireNonNull(serde, "serde is null"); - operatorContext.setInfoSupplier(exchangeClient::getStatus); + 
operatorContext.setInfoSupplier(exchangeDataSource::getInfo); } @Override @@ -123,8 +142,7 @@ public Supplier> addSplit(Split split) requireNonNull(split, "split is null"); checkArgument(split.getCatalogName().equals(REMOTE_CONNECTOR_ID), "split is not a remote split"); - RemoteSplit remoteSplit = (RemoteSplit) split.getConnectorSplit(); - exchangeClient.addLocation(remoteSplit.getTaskId(), remoteSplit.getLocation()); + exchangeDataSource.addSplit((RemoteSplit) split.getConnectorSplit()); return Optional::empty; } @@ -132,7 +150,7 @@ public Supplier> addSplit(Split split) @Override public void noMoreSplits() { - exchangeClient.noMoreLocations(); + exchangeDataSource.noMoreSplits(); } @Override @@ -150,15 +168,15 @@ public void finish() @Override public boolean isFinished() { - return exchangeClient.isFinished(); + return exchangeDataSource.isFinished(); } @Override public ListenableFuture isBlocked() { - // Avoid registering a new callback in the ExchangeClient when one is already pending + // Avoid registering a new callback in the data source when one is already pending if (isBlocked.isDone()) { - isBlocked = exchangeClient.isBlocked(); + isBlocked = exchangeDataSource.isBlocked(); if (isBlocked.isDone()) { isBlocked = NOT_BLOCKED; } @@ -181,7 +199,7 @@ public void addInput(Page page) @Override public Page getOutput() { - SerializedPage page = exchangeClient.pollPage(); + SerializedPage page = exchangeDataSource.pollPage(); if (page == null) { return null; } @@ -197,6 +215,290 @@ public Page getOutput() @Override public void close() { - exchangeClient.close(); + exchangeDataSource.close(); + } + + private interface ExchangeDataSource + extends Closeable + { + SerializedPage pollPage(); + + boolean isFinished(); + + ListenableFuture isBlocked(); + + void addSplit(RemoteSplit remoteSplit); + + void noMoreSplits(); + + OperatorInfo getInfo(); + + @Override + void close(); + } + + private static class LazyExchangeDataSource + implements ExchangeDataSource + { + 
private final DirectExchangeClientSupplier directExchangeClientSupplier; + private final LocalMemoryContext systemMemoryContext; + private final TaskFailureListener taskFailureListener; + private final RetryPolicy retryPolicy; + private final ExchangeManagerRegistry exchangeManagerRegistry; + + private final SettableFuture initializationFuture = SettableFuture.create(); + private final AtomicReference delegate = new AtomicReference<>(); + private final AtomicBoolean closed = new AtomicBoolean(); + + private LazyExchangeDataSource( + DirectExchangeClientSupplier directExchangeClientSupplier, + LocalMemoryContext systemMemoryContext, + TaskFailureListener taskFailureListener, + RetryPolicy retryPolicy, + ExchangeManagerRegistry exchangeManagerRegistry) + { + this.directExchangeClientSupplier = requireNonNull(directExchangeClientSupplier, "directExchangeClientSupplier is null"); + this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); + this.taskFailureListener = requireNonNull(taskFailureListener, "taskFailureListener is null"); + this.retryPolicy = requireNonNull(retryPolicy, "retryPolicy is null"); + this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); + } + + @Override + public SerializedPage pollPage() + { + if (closed.get()) { + return null; + } + + ExchangeDataSource dataSource = delegate.get(); + if (dataSource == null) { + return null; + } + return dataSource.pollPage(); + } + + @Override + public boolean isFinished() + { + if (closed.get()) { + return true; + } + ExchangeDataSource dataSource = delegate.get(); + if (dataSource == null) { + return false; + } + return dataSource.isFinished(); + } + + @Override + public ListenableFuture isBlocked() + { + if (closed.get()) { + return immediateVoidFuture(); + } + if (!initializationFuture.isDone()) { + return initializationFuture; + } + ExchangeDataSource dataSource = delegate.get(); + if (dataSource == null) { + return 
immediateVoidFuture(); + } + return dataSource.isBlocked(); + } + + @Override + public void addSplit(RemoteSplit remoteSplit) + { + SettableFuture future = null; + synchronized (this) { + if (closed.get()) { + return; + } + ExchangeDataSource dataSource = delegate.get(); + if (dataSource == null) { + if (remoteSplit.getExchangeInput() instanceof DirectExchangeInput) { + DirectExchangeClient directExchangeClient = directExchangeClientSupplier.get(systemMemoryContext, taskFailureListener, retryPolicy); + dataSource = new DirectExchangeDataSource(directExchangeClient); + dataSource.addSplit(remoteSplit); + } + else if (remoteSplit.getExchangeInput() instanceof ExternalExchangeInput) { + ExternalExchangeInput externalExchangeInput = (ExternalExchangeInput) remoteSplit.getExchangeInput(); + ExchangeManager exchangeManager = exchangeManagerRegistry.getExchangeManager(); + ExchangeSource exchangeSource = exchangeManager.createSource(externalExchangeInput.getExchangeSourceHandles()); + dataSource = new ExternalExchangeDataSource(exchangeSource, systemMemoryContext); + } + else { + throw new IllegalArgumentException("Unexpected split: " + remoteSplit); + } + delegate.set(dataSource); + future = initializationFuture; + } + else { + dataSource.addSplit(remoteSplit); + } + } + + if (future != null) { + future.set(null); + } + } + + @Override + public synchronized void noMoreSplits() + { + if (closed.get()) { + return; + } + ExchangeDataSource dataSource = delegate.get(); + if (dataSource != null) { + dataSource.noMoreSplits(); + } + else { + close(); + } + } + + @Override + public OperatorInfo getInfo() + { + ExchangeDataSource dataSource = delegate.get(); + if (dataSource == null) { + return null; + } + return dataSource.getInfo(); + } + + @Override + public void close() + { + SettableFuture future; + synchronized (this) { + if (!closed.compareAndSet(false, true)) { + return; + } + ExchangeDataSource dataSource = delegate.get(); + if (dataSource != null) { + 
dataSource.close(); + } + future = initializationFuture; + } + future.set(null); + } + } + + private static class DirectExchangeDataSource + implements ExchangeDataSource + { + private final DirectExchangeClient directExchangeClient; + + private DirectExchangeDataSource(DirectExchangeClient directExchangeClient) + { + this.directExchangeClient = requireNonNull(directExchangeClient, "directExchangeClient is null"); + } + + @Override + public SerializedPage pollPage() + { + return directExchangeClient.pollPage(); + } + + @Override + public boolean isFinished() + { + return directExchangeClient.isFinished(); + } + + @Override + public ListenableFuture isBlocked() + { + return directExchangeClient.isBlocked(); + } + + @Override + public void addSplit(RemoteSplit remoteSplit) + { + DirectExchangeInput taskInput = (DirectExchangeInput) remoteSplit.getExchangeInput(); + directExchangeClient.addLocation(taskInput.getTaskId(), taskInput.getLocation()); + } + + @Override + public void noMoreSplits() + { + directExchangeClient.noMoreLocations(); + } + + @Override + public OperatorInfo getInfo() + { + return directExchangeClient.getStatus(); + } + + @Override + public void close() + { + directExchangeClient.close(); + } + } + + private static class ExternalExchangeDataSource + implements ExchangeDataSource + { + private final ExchangeSource exchangeSource; + private final LocalMemoryContext systemMemoryContext; + + private ExternalExchangeDataSource(ExchangeSource exchangeSource, LocalMemoryContext systemMemoryContext) + { + this.exchangeSource = requireNonNull(exchangeSource, "exchangeSource is null"); + this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); + } + + @Override + public SerializedPage pollPage() + { + Slice data = exchangeSource.read(); + systemMemoryContext.setBytes(exchangeSource.getSystemMemoryUsage()); + if (data == null) { + return null; + } + // TODO: Avoid extra memory copy + return 
readSerializedPage(data.getInput()); + } + + @Override + public boolean isFinished() + { + return exchangeSource.isFinished(); + } + + @Override + public ListenableFuture isBlocked() + { + return asVoid(toListenableFuture(exchangeSource.isBlocked())); + } + + @Override + public void addSplit(RemoteSplit remoteSplit) + { + // ignore + } + + @Override + public void noMoreSplits() + { + // ignore + } + + @Override + public OperatorInfo getInfo() + { + return null; + } + + @Override + public void close() + { + exchangeSource.close(); + } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/MergeOperator.java b/core/trino-main/src/main/java/io/trino/operator/MergeOperator.java index 8a32faae1469..a84de0940f65 100644 --- a/core/trino-main/src/main/java/io/trino/operator/MergeOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/MergeOperator.java @@ -24,6 +24,7 @@ import io.trino.spi.connector.UpdatablePageSource; import io.trino.spi.type.Type; import io.trino.split.RemoteSplit; +import io.trino.split.RemoteSplit.DirectExchangeInput; import io.trino.sql.gen.OrderingCompiler; import io.trino.sql.planner.plan.PlanNodeId; @@ -48,7 +49,7 @@ public static class MergeOperatorFactory { private final int operatorId; private final PlanNodeId sourceId; - private final ExchangeClientSupplier exchangeClientSupplier; + private final DirectExchangeClientSupplier directExchangeClientSupplier; private final PagesSerdeFactory serdeFactory; private final List types; private final List outputChannels; @@ -61,7 +62,7 @@ public static class MergeOperatorFactory public MergeOperatorFactory( int operatorId, PlanNodeId sourceId, - ExchangeClientSupplier exchangeClientSupplier, + DirectExchangeClientSupplier directExchangeClientSupplier, PagesSerdeFactory serdeFactory, OrderingCompiler orderingCompiler, List types, @@ -71,7 +72,7 @@ public MergeOperatorFactory( { this.operatorId = operatorId; this.sourceId = requireNonNull(sourceId, "sourceId is null"); - 
this.exchangeClientSupplier = requireNonNull(exchangeClientSupplier, "exchangeClientSupplier is null"); + this.directExchangeClientSupplier = requireNonNull(directExchangeClientSupplier, "directExchangeClientSupplier is null"); this.serdeFactory = requireNonNull(serdeFactory, "serdeFactory is null"); this.types = requireNonNull(types, "types is null"); this.outputChannels = requireNonNull(outputChannels, "outputChannels is null"); @@ -96,7 +97,7 @@ public SourceOperator createOperator(DriverContext driverContext) return new MergeOperator( operatorContext, sourceId, - exchangeClientSupplier, + directExchangeClientSupplier, serdeFactory.createPagesSerde(), orderingCompiler.compilePageWithPositionComparator(types, sortChannels, sortOrder), outputChannels, @@ -112,7 +113,7 @@ public void noMoreOperators() private final OperatorContext operatorContext; private final PlanNodeId sourceId; - private final ExchangeClientSupplier exchangeClientSupplier; + private final DirectExchangeClientSupplier directExchangeClientSupplier; private final PagesSerde pagesSerde; private final PageWithPositionComparator comparator; private final List outputChannels; @@ -129,7 +130,7 @@ public void noMoreOperators() public MergeOperator( OperatorContext operatorContext, PlanNodeId sourceId, - ExchangeClientSupplier exchangeClientSupplier, + DirectExchangeClientSupplier directExchangeClientSupplier, PagesSerde pagesSerde, PageWithPositionComparator comparator, List outputChannels, @@ -137,7 +138,7 @@ public MergeOperator( { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.sourceId = requireNonNull(sourceId, "sourceId is null"); - this.exchangeClientSupplier = requireNonNull(exchangeClientSupplier, "exchangeClientSupplier is null"); + this.directExchangeClientSupplier = requireNonNull(directExchangeClientSupplier, "directExchangeClientSupplier is null"); this.pagesSerde = requireNonNull(pagesSerde, "pagesSerde is null"); this.comparator = 
requireNonNull(comparator, "comparator is null"); this.outputChannels = requireNonNull(outputChannels, "outputChannels is null"); @@ -158,11 +159,12 @@ public Supplier> addSplit(Split split) checkState(!blockedOnSplits.isDone(), "noMoreSplits has been called already"); TaskContext taskContext = operatorContext.getDriverContext().getPipelineContext().getTaskContext(); - ExchangeClient exchangeClient = closer.register(exchangeClientSupplier.get(operatorContext.localSystemMemoryContext(), taskContext::sourceTaskFailed, RetryPolicy.NONE)); + DirectExchangeClient client = closer.register(directExchangeClientSupplier.get(operatorContext.localSystemMemoryContext(), taskContext::sourceTaskFailed, RetryPolicy.NONE)); RemoteSplit remoteSplit = (RemoteSplit) split.getConnectorSplit(); - exchangeClient.addLocation(remoteSplit.getTaskId(), remoteSplit.getLocation()); - exchangeClient.noMoreLocations(); - pageProducers.add(exchangeClient.pages() + DirectExchangeInput taskInput = (DirectExchangeInput) remoteSplit.getExchangeInput(); + client.addLocation(taskInput.getTaskId(), taskInput.getLocation()); + client.noMoreLocations(); + pageProducers.add(client.pages() .map(serializedPage -> { operatorContext.recordNetworkInput(serializedPage.getSizeInBytes(), serializedPage.getPositionCount()); return pagesSerde.deserialize(serializedPage); diff --git a/core/trino-main/src/main/java/io/trino/operator/OperatorInfo.java b/core/trino-main/src/main/java/io/trino/operator/OperatorInfo.java index a5c392d78f8e..03b9de018a66 100644 --- a/core/trino-main/src/main/java/io/trino/operator/OperatorInfo.java +++ b/core/trino-main/src/main/java/io/trino/operator/OperatorInfo.java @@ -24,7 +24,7 @@ use = JsonTypeInfo.Id.NAME, property = "@type") @JsonSubTypes({ - @JsonSubTypes.Type(value = ExchangeClientStatus.class, name = "exchangeClientStatus"), + @JsonSubTypes.Type(value = DirectExchangeClientStatus.class, name = "directExchangeClientStatus"), @JsonSubTypes.Type(value = 
LocalExchangeBufferInfo.class, name = "localExchangeBuffer"), @JsonSubTypes.Type(value = TableFinishInfo.class, name = "tableFinish"), @JsonSubTypes.Type(value = SplitOperatorInfo.class, name = "splitOperator"), diff --git a/core/trino-main/src/main/java/io/trino/operator/StreamingExchangeClientBuffer.java b/core/trino-main/src/main/java/io/trino/operator/StreamingDirectExchangeBuffer.java similarity index 97% rename from core/trino-main/src/main/java/io/trino/operator/StreamingExchangeClientBuffer.java rename to core/trino-main/src/main/java/io/trino/operator/StreamingDirectExchangeBuffer.java index c027ff395a41..64a3d3a15b5e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/StreamingExchangeClientBuffer.java +++ b/core/trino-main/src/main/java/io/trino/operator/StreamingDirectExchangeBuffer.java @@ -36,8 +36,8 @@ import static java.lang.Math.max; import static java.util.Objects.requireNonNull; -public class StreamingExchangeClientBuffer - implements ExchangeClientBuffer +public class StreamingDirectExchangeBuffer + implements DirectExchangeBuffer { private final Executor executor; private final long bufferCapacityInBytes; @@ -59,7 +59,7 @@ public class StreamingExchangeClientBuffer @GuardedBy("this") private boolean closed; - public StreamingExchangeClientBuffer(Executor executor, DataSize bufferCapacity) + public StreamingDirectExchangeBuffer(Executor executor, DataSize bufferCapacity) { this.executor = requireNonNull(executor, "executor is null"); this.bufferCapacityInBytes = requireNonNull(bufferCapacity, "bufferCapacity is null").toBytes(); diff --git a/core/trino-main/src/main/java/io/trino/operator/index/IndexLoader.java b/core/trino-main/src/main/java/io/trino/operator/index/IndexLoader.java index 72925b147146..826780540841 100644 --- a/core/trino-main/src/main/java/io/trino/operator/index/IndexLoader.java +++ b/core/trino-main/src/main/java/io/trino/operator/index/IndexLoader.java @@ -20,7 +20,7 @@ import io.trino.connector.CatalogName; import 
io.trino.execution.Lifespan; import io.trino.execution.ScheduledSplit; -import io.trino.execution.TaskSource; +import io.trino.execution.SplitAssignment; import io.trino.metadata.Split; import io.trino.operator.Driver; import io.trino.operator.DriverFactory; @@ -242,7 +242,7 @@ public IndexedData streamIndexDataForSingleKey(UpdateRequest updateRequest) PageRecordSet pageRecordSet = new PageRecordSet(keyTypes, indexKeyTuple); PlanNodeId planNodeId = driverFactory.getSourceId().get(); ScheduledSplit split = new ScheduledSplit(0, planNodeId, new Split(INDEX_CONNECTOR_ID, new IndexSplit(pageRecordSet), Lifespan.taskWide())); - driver.updateSource(new TaskSource(planNodeId, ImmutableSet.of(split), true)); + driver.updateSplitAssignment(new SplitAssignment(planNodeId, ImmutableSet.of(split), true)); return new StreamingIndexedData(outputTypes, keyEqualOperators, indexKeyTuple, pageBuffer, driver); } @@ -338,7 +338,7 @@ public boolean load(List requests) try (Driver driver = driverFactory.createDriver(pipelineContext.addDriverContext())) { PlanNodeId sourcePlanNodeId = driverFactory.getSourceId().get(); ScheduledSplit split = new ScheduledSplit(0, sourcePlanNodeId, new Split(INDEX_CONNECTOR_ID, new IndexSplit(recordSetForLookupSource), Lifespan.taskWide())); - driver.updateSource(new TaskSource(sourcePlanNodeId, ImmutableSet.of(split), true)); + driver.updateSplitAssignment(new SplitAssignment(sourcePlanNodeId, ImmutableSet.of(split), true)); while (!driver.isFinished()) { ListenableFuture process = driver.process(); checkState(process.isDone(), "Driver should never block"); diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index d4a58e0d227a..e39986c786f9 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -64,6 +64,8 @@ import 
io.trino.execution.scheduler.ExecutionPolicy; import io.trino.execution.scheduler.PhasedExecutionPolicy; import io.trino.execution.scheduler.SplitSchedulerStats; +import io.trino.execution.scheduler.StageTaskSourceFactory; +import io.trino.execution.scheduler.TaskSourceFactory; import io.trino.failuredetector.FailureDetectorModule; import io.trino.memory.ClusterMemoryManager; import io.trino.memory.ForMemoryManager; @@ -282,6 +284,8 @@ protected void setup(Binder binder) binder.bind(SplitSchedulerStats.class).in(Scopes.SINGLETON); newExporter(binder).export(SplitSchedulerStats.class).withGeneratedName(); + binder.bind(TaskSourceFactory.class).to(StageTaskSourceFactory.class).in(Scopes.SINGLETON); + MapBinder executionPolicyBinder = newMapBinder(binder, String.class, ExecutionPolicy.class); executionPolicyBinder.addBinding("all-at-once").to(AllAtOnceExecutionPolicy.class); executionPolicyBinder.addBinding("phased").to(PhasedExecutionPolicy.class); diff --git a/core/trino-main/src/main/java/io/trino/server/PluginManager.java b/core/trino-main/src/main/java/io/trino/server/PluginManager.java index 9391d21f6b18..4864172b3562 100644 --- a/core/trino-main/src/main/java/io/trino/server/PluginManager.java +++ b/core/trino-main/src/main/java/io/trino/server/PluginManager.java @@ -17,6 +17,7 @@ import io.airlift.log.Logger; import io.trino.connector.ConnectorManager; import io.trino.eventlistener.EventListenerManager; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.resourcegroups.ResourceGroupManager; import io.trino.metadata.MetadataManager; import io.trino.security.AccessControlManager; @@ -29,6 +30,7 @@ import io.trino.spi.classloader.ThreadContextClassLoader; import io.trino.spi.connector.ConnectorFactory; import io.trino.spi.eventlistener.EventListenerFactory; +import io.trino.spi.exchange.ExchangeManagerFactory; import io.trino.spi.resourcegroups.ResourceGroupConfigurationManagerFactory; import 
io.trino.spi.security.CertificateAuthenticatorFactory; import io.trino.spi.security.GroupProviderFactory; @@ -76,6 +78,7 @@ public class PluginManager private final Optional headerAuthenticatorManager; private final EventListenerManager eventListenerManager; private final GroupProviderManager groupProviderManager; + private final ExchangeManagerRegistry exchangeManagerRegistry; private final SessionPropertyDefaults sessionPropertyDefaults; private final AtomicBoolean pluginsLoading = new AtomicBoolean(); private final AtomicBoolean pluginsLoaded = new AtomicBoolean(); @@ -92,7 +95,8 @@ public PluginManager( Optional headerAuthenticatorManager, EventListenerManager eventListenerManager, GroupProviderManager groupProviderManager, - SessionPropertyDefaults sessionPropertyDefaults) + SessionPropertyDefaults sessionPropertyDefaults, + ExchangeManagerRegistry exchangeManagerRegistry) { this.pluginsProvider = requireNonNull(pluginsProvider, "pluginsProvider is null"); this.connectorManager = requireNonNull(connectorManager, "connectorManager is null"); @@ -105,6 +109,7 @@ public PluginManager( this.eventListenerManager = requireNonNull(eventListenerManager, "eventListenerManager is null"); this.groupProviderManager = requireNonNull(groupProviderManager, "groupProviderManager is null"); this.sessionPropertyDefaults = requireNonNull(sessionPropertyDefaults, "sessionPropertyDefaults is null"); + this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); } public void loadPlugins() @@ -226,6 +231,11 @@ private void installPluginInternal(Plugin plugin, Supplier duplicat log.info("Registering group provider %s", groupProviderFactory.getName()); groupProviderManager.addGroupProviderFactory(groupProviderFactory); } + + for (ExchangeManagerFactory exchangeManagerFactory : plugin.getExchangeManagerFactories()) { + log.info("Registering exchange manager %s", exchangeManagerFactory.getName()); + 
exchangeManagerRegistry.addExchangeManagerFactory(exchangeManagerFactory); + } } private PluginClassLoader createClassLoader(List urls) diff --git a/core/trino-main/src/main/java/io/trino/server/Server.java b/core/trino-main/src/main/java/io/trino/server/Server.java index a480b719e4d5..9e5b992a16c3 100644 --- a/core/trino-main/src/main/java/io/trino/server/Server.java +++ b/core/trino-main/src/main/java/io/trino/server/Server.java @@ -40,6 +40,8 @@ import io.trino.client.NodeVersion; import io.trino.eventlistener.EventListenerManager; import io.trino.eventlistener.EventListenerModule; +import io.trino.exchange.ExchangeManagerModule; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.resourcegroups.ResourceGroupManager; import io.trino.execution.warnings.WarningCollectorModule; import io.trino.metadata.Catalog; @@ -104,6 +106,7 @@ private void doStart(String trinoVersion) new ServerSecurityModule(), new AccessControlModule(), new EventListenerModule(), + new ExchangeManagerModule(), new CoordinatorDiscoveryModule(), new ServerMainModule(trinoVersion), new GracefulShutdownModule(), @@ -134,6 +137,7 @@ private void doStart(String trinoVersion) .ifPresent(PasswordAuthenticatorManager::loadPasswordAuthenticator); injector.getInstance(EventListenerManager.class).loadEventListeners(); injector.getInstance(GroupProviderManager.class).loadConfiguredGroupProvider(); + injector.getInstance(ExchangeManagerRegistry.class).loadExchangeManager(); injector.getInstance(CertificateAuthenticatorManager.class).loadCertificateAuthenticator(); injector.getInstance(optionalKey(HeaderAuthenticatorManager.class)) .ifPresent(HeaderAuthenticatorManager::loadHeaderAuthenticator); diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java index 4bd96f692b31..137acd7ca7f7 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java +++ 
b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java @@ -81,9 +81,9 @@ import io.trino.metadata.SystemSecurityMetadata; import io.trino.metadata.TableProceduresPropertyManager; import io.trino.metadata.TablePropertyManager; -import io.trino.operator.ExchangeClientConfig; -import io.trino.operator.ExchangeClientFactory; -import io.trino.operator.ExchangeClientSupplier; +import io.trino.operator.DirectExchangeClientConfig; +import io.trino.operator.DirectExchangeClientFactory; +import io.trino.operator.DirectExchangeClientSupplier; import io.trino.operator.ForExchange; import io.trino.operator.GroupByHashPageIndexerFactory; import io.trino.operator.OperatorFactories; @@ -322,7 +322,7 @@ protected void setup(Binder binder) jaxrsBinder(binder).bind(PagesResponseWriter.class); // exchange client - binder.bind(ExchangeClientSupplier.class).to(ExchangeClientFactory.class).in(Scopes.SINGLETON); + binder.bind(DirectExchangeClientSupplier.class).to(DirectExchangeClientFactory.class).in(Scopes.SINGLETON); httpClientBinder(binder).bindHttpClient("exchange", ForExchange.class) .withTracing() .withFilter(GenerateTraceTokenRequestFilter.class) @@ -333,7 +333,7 @@ protected void setup(Binder binder) config.setMaxContentLength(DataSize.of(32, MEGABYTE)); }); - configBinder(binder).bindConfig(ExchangeClientConfig.class); + configBinder(binder).bindConfig(DirectExchangeClientConfig.class); binder.bind(ExchangeExecutionMBean.class).in(Scopes.SINGLETON); newExporter(binder).export(ExchangeExecutionMBean.class).withGeneratedName(); @@ -489,7 +489,7 @@ public static Executor createStartupExecutor(ServerConfig config) @Provides @Singleton @ForExchange - public static ScheduledExecutorService createExchangeExecutor(ExchangeClientConfig config) + public static ScheduledExecutorService createExchangeExecutor(DirectExchangeClientConfig config) { return newScheduledThreadPool(config.getClientThreads(), daemonThreadsNamed("exchange-client-%s")); } diff --git 
a/core/trino-main/src/main/java/io/trino/server/TaskResource.java b/core/trino-main/src/main/java/io/trino/server/TaskResource.java index c932bf35876b..4ca8c96f5581 100644 --- a/core/trino-main/src/main/java/io/trino/server/TaskResource.java +++ b/core/trino-main/src/main/java/io/trino/server/TaskResource.java @@ -151,7 +151,7 @@ public void createOrUpdateTask( TaskInfo taskInfo = taskManager.updateTask(session, taskId, taskUpdateRequest.getFragment(), - taskUpdateRequest.getSources(), + taskUpdateRequest.getSplitAssignments(), taskUpdateRequest.getOutputIds(), taskUpdateRequest.getDynamicFilterDomains()); diff --git a/core/trino-main/src/main/java/io/trino/server/TaskUpdateRequest.java b/core/trino-main/src/main/java/io/trino/server/TaskUpdateRequest.java index 6beaed826af4..b3943312c646 100644 --- a/core/trino-main/src/main/java/io/trino/server/TaskUpdateRequest.java +++ b/core/trino-main/src/main/java/io/trino/server/TaskUpdateRequest.java @@ -17,7 +17,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import io.trino.SessionRepresentation; -import io.trino.execution.TaskSource; +import io.trino.execution.SplitAssignment; import io.trino.execution.buffer.OutputBuffers; import io.trino.spi.predicate.Domain; import io.trino.sql.planner.PlanFragment; @@ -36,7 +36,7 @@ public class TaskUpdateRequest // extraCredentials is stored separately from SessionRepresentation to avoid being leaked private final Map extraCredentials; private final Optional fragment; - private final List sources; + private final List splitAssignments; private final OutputBuffers outputIds; private final Map dynamicFilterDomains; @@ -45,21 +45,21 @@ public TaskUpdateRequest( @JsonProperty("session") SessionRepresentation session, @JsonProperty("extraCredentials") Map extraCredentials, @JsonProperty("fragment") Optional fragment, - @JsonProperty("sources") List sources, + @JsonProperty("splitAssignments") List splitAssignments, 
@JsonProperty("outputIds") OutputBuffers outputIds, @JsonProperty("dynamicFilterDomains") Map dynamicFilterDomains) { requireNonNull(session, "session is null"); requireNonNull(extraCredentials, "extraCredentials is null"); requireNonNull(fragment, "fragment is null"); - requireNonNull(sources, "sources is null"); + requireNonNull(splitAssignments, "splitAssignments is null"); requireNonNull(outputIds, "outputIds is null"); requireNonNull(dynamicFilterDomains, "dynamicFilterDomains is null"); this.session = session; this.extraCredentials = extraCredentials; this.fragment = fragment; - this.sources = ImmutableList.copyOf(sources); + this.splitAssignments = ImmutableList.copyOf(splitAssignments); this.outputIds = outputIds; this.dynamicFilterDomains = dynamicFilterDomains; } @@ -83,9 +83,9 @@ public Optional getFragment() } @JsonProperty - public List getSources() + public List getSplitAssignments() { - return sources; + return splitAssignments; } @JsonProperty @@ -107,7 +107,7 @@ public String toString() .add("session", session) .add("extraCredentials", extraCredentials.keySet()) .add("fragment", fragment) - .add("sources", sources) + .add("splitAssignments", splitAssignments) .add("outputIds", outputIds) .add("dynamicFilterDomains", dynamicFilterDomains) .toString(); diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java b/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java index 7063413c618b..91b6333f39c7 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java @@ -24,7 +24,7 @@ import io.trino.client.ProtocolHeaders; import io.trino.client.QueryResults; import io.trino.execution.QueryManager; -import io.trino.operator.ExchangeClientSupplier; +import io.trino.operator.DirectExchangeClientSupplier; import io.trino.server.ForStatementResource; import 
io.trino.server.ServerConfig; import io.trino.server.security.ResourceSecurity; @@ -80,7 +80,7 @@ public class ExecutingStatementResource private static final DataSize MAX_TARGET_RESULT_SIZE = DataSize.of(128, MEGABYTE); private final QueryManager queryManager; - private final ExchangeClientSupplier exchangeClientSupplier; + private final DirectExchangeClientSupplier directExchangeClientSupplier; private final BlockEncodingSerde blockEncodingSerde; private final QueryInfoUrlFactory queryInfoUrlFactory; private final BoundedExecutor responseExecutor; @@ -93,7 +93,7 @@ public class ExecutingStatementResource @Inject public ExecutingStatementResource( QueryManager queryManager, - ExchangeClientSupplier exchangeClientSupplier, + DirectExchangeClientSupplier directExchangeClientSupplier, BlockEncodingSerde blockEncodingSerde, QueryInfoUrlFactory queryInfoUrlTemplate, @ForStatementResource BoundedExecutor responseExecutor, @@ -101,7 +101,7 @@ public ExecutingStatementResource( ServerConfig serverConfig) { this.queryManager = requireNonNull(queryManager, "queryManager is null"); - this.exchangeClientSupplier = requireNonNull(exchangeClientSupplier, "exchangeClientSupplier is null"); + this.directExchangeClientSupplier = requireNonNull(directExchangeClientSupplier, "directExchangeClientSupplier is null"); this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); this.queryInfoUrlFactory = requireNonNull(queryInfoUrlTemplate, "queryInfoUrlTemplate is null"); this.responseExecutor = requireNonNull(responseExecutor, "responseExecutor is null"); @@ -183,7 +183,7 @@ protected Query getQuery(QueryId queryId, String slug, long token) querySlug, queryManager, queryInfoUrlFactory.getQueryInfoUrl(queryId), - exchangeClientSupplier, + directExchangeClientSupplier, responseExecutor, timeoutExecutor, blockEncodingSerde)); diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java 
b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java index 9e6c53d70a96..2209143f1f68 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java @@ -49,8 +49,8 @@ import io.trino.execution.buffer.PagesSerdeFactory; import io.trino.execution.buffer.SerializedPage; import io.trino.memory.context.SimpleLocalMemoryContext; -import io.trino.operator.ExchangeClient; -import io.trino.operator.ExchangeClientSupplier; +import io.trino.operator.DirectExchangeClient; +import io.trino.operator.DirectExchangeClientSupplier; import io.trino.spi.ErrorCode; import io.trino.spi.Page; import io.trino.spi.QueryId; @@ -130,7 +130,7 @@ class Query private final Optional queryInfoUrl; @GuardedBy("this") - private final ExchangeClient exchangeClient; + private final DirectExchangeClient exchangeClient; private final Executor resultsProcessorExecutor; private final ScheduledExecutorService timeoutExecutor; @@ -194,12 +194,12 @@ public static Query create( Slug slug, QueryManager queryManager, Optional queryInfoUrl, - ExchangeClientSupplier exchangeClientSupplier, + DirectExchangeClientSupplier directExchangeClientSupplier, Executor dataProcessorExecutor, ScheduledExecutorService timeoutExecutor, BlockEncodingSerde blockEncodingSerde) { - ExchangeClient exchangeClient = exchangeClientSupplier.get( + DirectExchangeClient exchangeClient = directExchangeClientSupplier.get( new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), Query.class.getSimpleName()), queryManager::outputTaskFailed, getRetryPolicy(session)); @@ -223,7 +223,7 @@ private Query( Slug slug, QueryManager queryManager, Optional queryInfoUrl, - ExchangeClient exchangeClient, + DirectExchangeClient exchangeClient, Executor resultsProcessorExecutor, ScheduledExecutorService timeoutExecutor, BlockEncodingSerde blockEncodingSerde) diff --git 
a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java index fedf562dce19..cab29dfa30c7 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java @@ -39,10 +39,10 @@ import io.trino.execution.PartitionedSplitsInfo; import io.trino.execution.RemoteTask; import io.trino.execution.ScheduledSplit; +import io.trino.execution.SplitAssignment; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.TaskId; import io.trino.execution.TaskInfo; -import io.trino.execution.TaskSource; import io.trino.execution.TaskState; import io.trino.execution.TaskStatus; import io.trino.execution.buffer.BufferInfo; @@ -528,17 +528,17 @@ private synchronized void updateSplitQueueSpace() } } - private synchronized void processTaskUpdate(TaskInfo newValue, List sources) + private synchronized void processTaskUpdate(TaskInfo newValue, List splitAssignments) { updateTaskInfo(newValue); // remove acknowledged splits, which frees memory - for (TaskSource source : sources) { - PlanNodeId planNodeId = source.getPlanNodeId(); + for (SplitAssignment assignment : splitAssignments) { + PlanNodeId planNodeId = assignment.getPlanNodeId(); boolean isPartitionedSource = planFragment.isPartitionedSources(planNodeId); int removed = 0; long removedWeight = 0; - for (ScheduledSplit split : source.getSplits()) { + for (ScheduledSplit split : assignment.getSplits()) { if (pendingSplits.remove(planNodeId, split)) { if (isPartitionedSource) { removed++; @@ -546,10 +546,10 @@ private synchronized void processTaskUpdate(TaskInfo newValue, List } } } - if (source.isNoMoreSplits()) { + if (assignment.isNoMoreSplits()) { noMoreSplits.put(planNodeId, false); } - for (Lifespan lifespan : source.getNoMoreSplitsForLifespan()) { + for (Lifespan lifespan : 
assignment.getNoMoreSplitsForLifespan()) { pendingNoMoreSplitsForLifespan.remove(planNodeId, lifespan); } if (isPartitionedSource) { @@ -601,7 +601,7 @@ private synchronized void sendUpdate() return; } - List sources = getSources(); + List splitAssignments = getSplitAssignments(); VersionedDynamicFilterDomains dynamicFilterDomains = outboundDynamicFiltersCollector.acknowledgeAndGetNewDomains(sentDynamicFiltersVersion); // Workers don't need the embedded JSON representation when the fragment is sent @@ -610,7 +610,7 @@ private synchronized void sendUpdate() session.toSessionRepresentation(), session.getIdentity().getExtraCredentials(), fragment, - sources, + splitAssignments, outputBuffers.get(), dynamicFilterDomains.getDynamicFilterDomains()); byte[] taskUpdateRequestJson = taskUpdateRequestCodec.toJsonBytes(updateRequest); @@ -640,32 +640,32 @@ private synchronized void sendUpdate() Futures.addCallback( future, - new SimpleHttpResponseHandler<>(new UpdateResponseHandler(sources, dynamicFilterDomains.getVersion()), request.getUri(), stats), + new SimpleHttpResponseHandler<>(new UpdateResponseHandler(splitAssignments, dynamicFilterDomains.getVersion()), request.getUri(), stats), executor); } - private synchronized List getSources() + private synchronized List getSplitAssignments() { return Stream.concat(planFragment.getPartitionedSourceNodes().stream(), planFragment.getRemoteSourceNodes().stream()) .filter(Objects::nonNull) .map(PlanNode::getId) - .map(this::getSource) + .map(this::getSplitAssignment) .filter(Objects::nonNull) .collect(toImmutableList()); } - private synchronized TaskSource getSource(PlanNodeId planNodeId) + private synchronized SplitAssignment getSplitAssignment(PlanNodeId planNodeId) { Set splits = pendingSplits.get(planNodeId); boolean pendingNoMoreSplits = Boolean.TRUE.equals(this.noMoreSplits.get(planNodeId)); boolean noMoreSplits = this.noMoreSplits.containsKey(planNodeId); Set noMoreSplitsForLifespan = 
pendingNoMoreSplitsForLifespan.get(planNodeId); - TaskSource element = null; + SplitAssignment assignment = null; if (!splits.isEmpty() || !noMoreSplitsForLifespan.isEmpty() || pendingNoMoreSplits) { - element = new TaskSource(planNodeId, splits, noMoreSplitsForLifespan, noMoreSplits); + assignment = new SplitAssignment(planNodeId, splits, noMoreSplitsForLifespan, noMoreSplits); } - return element; + return assignment; } @Override @@ -859,12 +859,12 @@ public String toString() private class UpdateResponseHandler implements SimpleHttpResponseCallback { - private final List sources; + private final List splitAssignments; private final long currentRequestDynamicFiltersVersion; - private UpdateResponseHandler(List sources, long currentRequestDynamicFiltersVersion) + private UpdateResponseHandler(List splitAssignments, long currentRequestDynamicFiltersVersion) { - this.sources = ImmutableList.copyOf(requireNonNull(sources, "sources is null")); + this.splitAssignments = ImmutableList.copyOf(requireNonNull(splitAssignments, "splitAssignments is null")); this.currentRequestDynamicFiltersVersion = currentRequestDynamicFiltersVersion; } @@ -883,7 +883,7 @@ public void success(TaskInfo value) // Remove dynamic filters which were successfully sent to free up memory outboundDynamicFiltersCollector.acknowledge(currentRequestDynamicFiltersVersion); updateStats(currentRequestStartNanos); - processTaskUpdate(value, sources); + processTaskUpdate(value, splitAssignments); updateErrorTracker.requestSucceeded(); } finally { diff --git a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java index 75e6778eeed2..e11bcbade326 100644 --- a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java +++ b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java @@ -44,6 +44,7 @@ import io.trino.dispatcher.DispatchManager; import 
io.trino.eventlistener.EventListenerConfig; import io.trino.eventlistener.EventListenerManager; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.FailureInjector; import io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.execution.QueryInfo; @@ -167,6 +168,7 @@ public static Builder builder() private final MBeanServer mBeanServer; private final boolean coordinator; private final FailureInjector failureInjector; + private final ExchangeManagerRegistry exchangeManagerRegistry; public static class TestShutdownAction implements ShutdownAction @@ -256,6 +258,7 @@ private TestingTrinoServer( binder.bind(ShutdownAction.class).to(TestShutdownAction.class).in(Scopes.SINGLETON); binder.bind(GracefulShutdownHandler.class).in(Scopes.SINGLETON); binder.bind(ProcedureTester.class).in(Scopes.SINGLETON); + binder.bind(ExchangeManagerRegistry.class).in(Scopes.SINGLETON); }); if (discoveryUri.isPresent()) { @@ -328,6 +331,7 @@ private TestingTrinoServer( mBeanServer = injector.getInstance(MBeanServer.class); announcer = injector.getInstance(Announcer.class); failureInjector = injector.getInstance(FailureInjector.class); + exchangeManagerRegistry = injector.getInstance(ExchangeManagerRegistry.class); accessControl.setSystemAccessControls(systemAccessControls); @@ -400,6 +404,11 @@ public CatalogName createCatalog(String catalogName, String connectorName, Map properties) + { + exchangeManagerRegistry.loadExchangeManager(name, properties); + } + public Path getBaseDataDir() { return baseDataDir; diff --git a/core/trino-main/src/main/java/io/trino/split/RemoteSplit.java b/core/trino-main/src/main/java/io/trino/split/RemoteSplit.java index e043fd1a2b22..725f02452106 100644 --- a/core/trino-main/src/main/java/io/trino/split/RemoteSplit.java +++ b/core/trino-main/src/main/java/io/trino/split/RemoteSplit.java @@ -15,10 +15,13 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import 
com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.google.common.collect.ImmutableList; import io.trino.execution.TaskId; import io.trino.spi.HostAddress; import io.trino.spi.connector.ConnectorSplit; +import io.trino.spi.exchange.ExchangeSourceHandle; import java.net.URI; import java.util.List; @@ -29,26 +32,18 @@ public class RemoteSplit implements ConnectorSplit { - private final TaskId taskId; - private final URI location; + private final ExchangeInput exchangeInput; @JsonCreator - public RemoteSplit(@JsonProperty("taskId") TaskId taskId, @JsonProperty("location") URI location) + public RemoteSplit(@JsonProperty("exchangeInput") ExchangeInput exchangeInput) { - this.taskId = requireNonNull(taskId, "taskId is null"); - this.location = requireNonNull(location, "location is null"); + this.exchangeInput = requireNonNull(exchangeInput, "exchangeInput is null"); } @JsonProperty - public TaskId getTaskId() + public ExchangeInput getExchangeInput() { - return taskId; - } - - @JsonProperty - public URI getLocation() - { - return location; + return exchangeInput; } @Override @@ -73,8 +68,78 @@ public List getAddresses() public String toString() { return toStringHelper(this) - .add("taskId", taskId) - .add("location", location) + .add("exchangeInput", exchangeInput) .toString(); } + + @JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + property = "@type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = DirectExchangeInput.class, name = "direct"), + @JsonSubTypes.Type(value = ExternalExchangeInput.class, name = "external")}) + public interface ExchangeInput + { + } + + public static class DirectExchangeInput + implements ExchangeInput + { + private final TaskId taskId; + private final URI location; + + @JsonCreator + public DirectExchangeInput(@JsonProperty("taskId") TaskId taskId, @JsonProperty("location") URI location) + { + this.taskId = requireNonNull(taskId, "taskId is null"); + this.location = 
requireNonNull(location, "location is null"); + } + + @JsonProperty + public TaskId getTaskId() + { + return taskId; + } + + @JsonProperty + public URI getLocation() + { + return location; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("taskId", taskId) + .add("location", location) + .toString(); + } + } + + public static class ExternalExchangeInput + implements ExchangeInput + { + private final List exchangeSourceHandles; + + @JsonCreator + public ExternalExchangeInput(@JsonProperty("exchangeSourceHandles") List exchangeSourceHandles) + { + this.exchangeSourceHandles = ImmutableList.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null")); + } + + @JsonProperty + public List getExchangeSourceHandles() + { + return exchangeSourceHandles; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("exchangeSourceHandles", exchangeSourceHandles) + .toString(); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 03cbf841a3d3..63a6d27769b7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -31,6 +31,7 @@ import io.airlift.units.DataSize; import io.trino.Session; import io.trino.SystemSessionProperties; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.DynamicFilterConfig; import io.trino.execution.ExplainAnalyzeContext; import io.trino.execution.StageId; @@ -48,11 +49,11 @@ import io.trino.operator.AssignUniqueIdOperator; import io.trino.operator.DeleteOperator.DeleteOperatorFactory; import io.trino.operator.DevNullOperator.DevNullOperatorFactory; +import io.trino.operator.DirectExchangeClientSupplier; import io.trino.operator.DriverFactory; import io.trino.operator.DynamicFilterSourceOperator; 
import io.trino.operator.DynamicFilterSourceOperator.DynamicFilterSourceOperatorFactory; import io.trino.operator.EnforceSingleRowOperator; -import io.trino.operator.ExchangeClientSupplier; import io.trino.operator.ExchangeOperator.ExchangeOperatorFactory; import io.trino.operator.ExplainAnalyzeOperator.ExplainAnalyzeOperatorFactory; import io.trino.operator.FilterAndProjectOperator; @@ -359,7 +360,7 @@ public class LocalExecutionPlanner private final IndexManager indexManager; private final NodePartitioningManager nodePartitioningManager; private final PageSinkManager pageSinkManager; - private final ExchangeClientSupplier exchangeClientSupplier; + private final DirectExchangeClientSupplier directExchangeClientSupplier; private final ExpressionCompiler expressionCompiler; private final PageFunctionCompiler pageFunctionCompiler; private final JoinFilterFunctionCompiler joinFilterFunctionCompiler; @@ -379,6 +380,7 @@ public class LocalExecutionPlanner private final TypeOperators typeOperators; private final BlockTypeOperators blockTypeOperators; private final TableExecuteContextManager tableExecuteContextManager; + private final ExchangeManagerRegistry exchangeManagerRegistry; @Inject public LocalExecutionPlanner( @@ -389,7 +391,7 @@ public LocalExecutionPlanner( IndexManager indexManager, NodePartitioningManager nodePartitioningManager, PageSinkManager pageSinkManager, - ExchangeClientSupplier exchangeClientSupplier, + DirectExchangeClientSupplier directExchangeClientSupplier, ExpressionCompiler expressionCompiler, PageFunctionCompiler pageFunctionCompiler, JoinFilterFunctionCompiler joinFilterFunctionCompiler, @@ -405,13 +407,14 @@ public LocalExecutionPlanner( DynamicFilterConfig dynamicFilterConfig, TypeOperators typeOperators, BlockTypeOperators blockTypeOperators, - TableExecuteContextManager tableExecuteContextManager) + TableExecuteContextManager tableExecuteContextManager, + ExchangeManagerRegistry exchangeManagerRegistry) { this.explainAnalyzeContext = 
requireNonNull(explainAnalyzeContext, "explainAnalyzeContext is null"); this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null"); this.indexManager = requireNonNull(indexManager, "indexManager is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); - this.exchangeClientSupplier = exchangeClientSupplier; + this.directExchangeClientSupplier = directExchangeClientSupplier; this.metadata = requireNonNull(metadata, "metadata is null"); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); @@ -434,6 +437,7 @@ public LocalExecutionPlanner( this.typeOperators = requireNonNull(typeOperators, "typeOperators is null"); this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null"); this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); + this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); } public LocalExecutionPlan plan( @@ -878,7 +882,7 @@ private PhysicalOperation createMergeSource(RemoteSourceNode node, LocalExecutio OperatorFactory operatorFactory = new MergeOperatorFactory( context.getNextOperatorId(), node.getId(), - exchangeClientSupplier, + directExchangeClientSupplier, new PagesSerdeFactory(metadata.getBlockEncodingSerde(), isExchangeCompressionEnabled(session)), orderingCompiler, types, @@ -898,9 +902,10 @@ private PhysicalOperation createRemoteSource(RemoteSourceNode node, LocalExecuti OperatorFactory operatorFactory = new ExchangeOperatorFactory( context.getNextOperatorId(), node.getId(), - exchangeClientSupplier, + directExchangeClientSupplier, new PagesSerdeFactory(metadata.getBlockEncodingSerde(), isExchangeCompressionEnabled(session)), - node.getRetryPolicy()); + node.getRetryPolicy(), + exchangeManagerRegistry); return 
new PhysicalOperation(operatorFactory, makeLayout(node), context, UNGROUPED_EXECUTION); } diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index e4756422f39a..e5fe7a62ec39 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -50,6 +50,7 @@ import io.trino.cost.TaskCountEstimator; import io.trino.eventlistener.EventListenerConfig; import io.trino.eventlistener.EventListenerManager; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.DynamicFilterConfig; import io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.execution.Lifespan; @@ -58,9 +59,9 @@ import io.trino.execution.QueryPreparer; import io.trino.execution.QueryPreparer.PreparedQuery; import io.trino.execution.ScheduledSplit; +import io.trino.execution.SplitAssignment; import io.trino.execution.TableExecuteContextManager; import io.trino.execution.TaskManagerConfig; -import io.trino.execution.TaskSource; import io.trino.execution.resourcegroups.NoOpResourceGroupManager; import io.trino.execution.scheduler.NodeScheduler; import io.trino.execution.scheduler.NodeSchedulerConfig; @@ -248,6 +249,7 @@ public class LocalQueryRunner private final JoinCompiler joinCompiler; private final ConnectorManager connectorManager; private final PluginManager pluginManager; + private final ExchangeManagerRegistry exchangeManagerRegistry; private final TaskManagerConfig taskManagerConfig; private final boolean alwaysRevokeMemory; @@ -347,6 +349,7 @@ private LocalQueryRunner( this.joinFilterFunctionCompiler = new JoinFilterFunctionCompiler(metadata); NodeInfo nodeInfo = new NodeInfo("test"); + HandleResolver handleResolver = new HandleResolver(); this.connectorManager = new ConnectorManager( metadata, catalogManager, @@ -356,7 +359,7 @@ private LocalQueryRunner( indexManager, 
nodePartitioningManager, pageSinkManager, - new HandleResolver(), + handleResolver, nodeManager, nodeInfo, testingVersionEmbedder(), @@ -380,6 +383,7 @@ private LocalQueryRunner( new TransactionsSystemTable(metadata, transactionManager)), ImmutableSet.of()); + exchangeManagerRegistry = new ExchangeManagerRegistry(handleResolver); this.pluginManager = new PluginManager( (loader, createClassLoader) -> {}, connectorManager, @@ -391,7 +395,8 @@ private LocalQueryRunner( Optional.of(new HeaderAuthenticatorManager(new HeaderAuthenticatorConfig())), eventListenerManager, new GroupProviderManager(), - new SessionPropertyDefaults(nodeInfo, accessControl)); + new SessionPropertyDefaults(nodeInfo, accessControl), + exchangeManagerRegistry); connectorManager.addConnectorFactory(globalSystemConnectorFactory, globalSystemConnectorFactory.getClass()::getClassLoader); connectorManager.createCatalog(GlobalSystemConnector.NAME, GlobalSystemConnector.NAME, ImmutableMap.of()); @@ -747,6 +752,12 @@ public void injectTaskFailure( throw new UnsupportedOperationException("failure injection is not supported"); } + @Override + public void loadExchangeManager(String name, Map properties) + { + exchangeManagerRegistry.loadExchangeManager(name, properties); + } + public List createDrivers(@Language("SQL") String sql, OutputFactory outputFactory, TaskContext taskContext) { return createDrivers(defaultSession, sql, outputFactory, taskContext); @@ -800,7 +811,8 @@ private List createDrivers(Session session, Plan plan, OutputFactory out new DynamicFilterConfig(), typeOperators, blockTypeOperators, - tableExecuteContextManager); + tableExecuteContextManager, + exchangeManagerRegistry); // plan query StageExecutionDescriptor stageExecutionDescriptor = subplan.getFragment().getStageExecutionDescriptor(); @@ -813,8 +825,8 @@ private List createDrivers(Session session, Plan plan, OutputFactory out subplan.getFragment().getPartitionedSources(), outputFactory); - // generate sources - List sources = new 
ArrayList<>(); + // generate splitAssignments + List splitAssignments = new ArrayList<>(); long sequenceId = 0; for (TableScanNode tableScan : findTableScanNodes(subplan.getFragment().getRoot())) { TableHandle table = tableScan.getTable(); @@ -833,7 +845,7 @@ private List createDrivers(Session session, Plan plan, OutputFactory out } } - sources.add(new TaskSource(tableScan.getId(), scheduledSplits.build(), true)); + splitAssignments.add(new SplitAssignment(tableScan.getId(), scheduledSplits.build(), true)); } // create drivers @@ -852,16 +864,16 @@ private List createDrivers(Session session, Plan plan, OutputFactory out } } - // add sources to the drivers + // add split assignments to the drivers ImmutableSet partitionedSources = ImmutableSet.copyOf(subplan.getFragment().getPartitionedSources()); - for (TaskSource source : sources) { - DriverFactory driverFactory = driverFactoriesBySource.get(source.getPlanNodeId()); + for (SplitAssignment splitAssignment : splitAssignments) { + DriverFactory driverFactory = driverFactoriesBySource.get(splitAssignment.getPlanNodeId()); checkState(driverFactory != null); boolean partitioned = partitionedSources.contains(driverFactory.getSourceId().get()); - for (ScheduledSplit split : source.getSplits()) { + for (ScheduledSplit split : splitAssignment.getSplits()) { DriverContext driverContext = taskContext.addPipelineContext(driverFactory.getPipelineId(), driverFactory.isInputDriver(), driverFactory.isOutputDriver(), partitioned).addDriverContext(); Driver driver = driverFactory.createDriver(driverContext); - driver.updateSource(new TaskSource(split.getPlanNodeId(), ImmutableSet.of(split), true)); + driver.updateSplitAssignment(new SplitAssignment(split.getPlanNodeId(), ImmutableSet.of(split), true)); drivers.add(driver); } } diff --git a/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java index bcfcda7ccf23..552fead84509 100644 --- 
a/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java @@ -98,6 +98,8 @@ void injectTaskFailure( InjectedFailureType injectionType, Optional errorType); + void loadExchangeManager(String name, Map properties); + class MaterializedResultWithPlan { private final MaterializedResult materializedResult; diff --git a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java index 8dd5e83d03e5..4bf40c6c8cd2 100644 --- a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java @@ -26,6 +26,7 @@ import io.airlift.units.Duration; import io.trino.Session; import io.trino.cost.StatsAndCosts; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker; import io.trino.execution.buffer.LazyOutputBuffer; import io.trino.execution.buffer.OutputBuffer; @@ -33,6 +34,7 @@ import io.trino.memory.MemoryPool; import io.trino.memory.QueryContext; import io.trino.memory.context.SimpleLocalMemoryContext; +import io.trino.metadata.HandleResolver; import io.trino.metadata.InternalNode; import io.trino.metadata.Split; import io.trino.operator.TaskContext; @@ -213,7 +215,8 @@ public MockRemoteTask( DataSize.ofBytes(1), DataSize.ofBytes(1), () -> new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), - () -> {}); + () -> {}, + new ExchangeManagerRegistry(new HandleResolver())); this.fragment = requireNonNull(fragment, "fragment is null"); this.nodeId = requireNonNull(nodeId, "nodeId is null"); diff --git a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java index 5c698ee583cd..e718b2c1ac05 100644 --- 
a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java @@ -21,12 +21,14 @@ import io.trino.event.SplitMonitor; import io.trino.eventlistener.EventListenerConfig; import io.trino.eventlistener.EventListenerManager; -import io.trino.execution.TestSqlTaskManager.MockExchangeClientSupplier; +import io.trino.exchange.ExchangeManagerRegistry; +import io.trino.execution.TestSqlTaskManager.MockDirectExchangeClientSupplier; import io.trino.execution.buffer.OutputBuffers; import io.trino.execution.scheduler.NodeScheduler; import io.trino.execution.scheduler.NodeSchedulerConfig; import io.trino.execution.scheduler.UniformNodeSelectorFactory; import io.trino.index.IndexManager; +import io.trino.metadata.HandleResolver; import io.trino.metadata.InMemoryNodeManager; import io.trino.metadata.Metadata; import io.trino.metadata.Split; @@ -80,7 +82,7 @@ private TaskTestUtils() {} public static final ScheduledSplit SPLIT = new ScheduledSplit(0, TABLE_SCAN_NODE_ID, new Split(CONNECTOR_ID, TestingSplit.createLocalSplit(), Lifespan.taskWide())); - public static final ImmutableList EMPTY_SOURCES = ImmutableList.of(); + public static final ImmutableList EMPTY_SPLIT_ASSIGNMENTS = ImmutableList.of(); public static final Symbol SYMBOL = new Symbol("column"); @@ -129,7 +131,7 @@ public static LocalExecutionPlanner createTestingPlanner() new IndexManager(), nodePartitioningManager, new PageSinkManager(), - new MockExchangeClientSupplier(), + new MockDirectExchangeClientSupplier(), new ExpressionCompiler(metadata, pageFunctionCompiler), pageFunctionCompiler, new JoinFilterFunctionCompiler(metadata), @@ -151,12 +153,13 @@ public static LocalExecutionPlanner createTestingPlanner() new DynamicFilterConfig(), typeOperators, blockTypeOperators, - new TableExecuteContextManager()); + new TableExecuteContextManager(), + new ExchangeManagerRegistry(new HandleResolver())); } - public static TaskInfo 
updateTask(SqlTask sqlTask, List taskSources, OutputBuffers outputBuffers) + public static TaskInfo updateTask(SqlTask sqlTask, List splitAssignments, OutputBuffers outputBuffers) { - return sqlTask.updateTask(TEST_SESSION, Optional.of(PLAN_FRAGMENT), taskSources, outputBuffers, ImmutableMap.of()); + return sqlTask.updateTask(TEST_SESSION, Optional.of(PLAN_FRAGMENT), splitAssignments, outputBuffers, ImmutableMap.of()); } public static SplitMonitor createTestSplitMonitor() diff --git a/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java b/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java index 14eebbb832ea..70a70a0d7c9d 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java @@ -21,10 +21,12 @@ import io.airlift.stats.CounterStat; import io.airlift.stats.TestingGcMonitor; import io.airlift.units.DataSize; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.executor.TaskExecutor; import io.trino.memory.MemoryPool; import io.trino.memory.QueryContext; import io.trino.memory.context.LocalMemoryContext; +import io.trino.metadata.HandleResolver; import io.trino.operator.DriverContext; import io.trino.operator.OperatorContext; import io.trino.operator.PipelineContext; @@ -306,6 +308,7 @@ private SqlTask newSqlTask(QueryId queryId) sqlTask -> {}, DataSize.of(32, MEGABYTE), DataSize.of(200, MEGABYTE), + new ExchangeManagerRegistry(new HandleResolver()), new CounterStat()); } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java index ea962a0437cf..53e3e9c32902 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java @@ -21,6 +21,7 @@ import io.airlift.stats.CounterStat; import 
io.airlift.stats.TestingGcMonitor; import io.airlift.units.DataSize; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.buffer.BufferResult; import io.trino.execution.buffer.BufferState; import io.trino.execution.buffer.OutputBuffers; @@ -28,6 +29,7 @@ import io.trino.execution.executor.TaskExecutor; import io.trino.memory.MemoryPool; import io.trino.memory.QueryContext; +import io.trino.metadata.HandleResolver; import io.trino.operator.TaskContext; import io.trino.spi.QueryId; import io.trino.spi.memory.MemoryPoolId; @@ -51,7 +53,7 @@ import static io.trino.execution.DynamicFiltersCollector.INITIAL_DYNAMIC_FILTERS_VERSION; import static io.trino.execution.SqlTask.createSqlTask; import static io.trino.execution.TaskStatus.STARTING_VERSION; -import static io.trino.execution.TaskTestUtils.EMPTY_SOURCES; +import static io.trino.execution.TaskTestUtils.EMPTY_SPLIT_ASSIGNMENTS; import static io.trino.execution.TaskTestUtils.PLAN_FRAGMENT; import static io.trino.execution.TaskTestUtils.SPLIT; import static io.trino.execution.TaskTestUtils.TABLE_SCAN_NODE_ID; @@ -129,7 +131,7 @@ public void testEmptyQuery() taskInfo = sqlTask.updateTask(TEST_SESSION, Optional.of(PLAN_FRAGMENT), - ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(), true)), + ImmutableList.of(new SplitAssignment(TABLE_SCAN_NODE_ID, ImmutableSet.of(), true)), createInitialEmptyOutputBuffers(PARTITIONED) .withNoMoreBufferIds(), ImmutableMap.of()); @@ -149,7 +151,7 @@ public void testSimpleQuery() assertEquals(sqlTask.getTaskStatus().getVersion(), STARTING_VERSION); sqlTask.updateTask(TEST_SESSION, Optional.of(PLAN_FRAGMENT), - ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)), + ImmutableList.of(new SplitAssignment(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)), createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(), ImmutableMap.of()); @@ -222,7 +224,7 @@ public void testAbort() 
assertEquals(sqlTask.getTaskStatus().getVersion(), STARTING_VERSION); sqlTask.updateTask(TEST_SESSION, Optional.of(PLAN_FRAGMENT), - ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)), + ImmutableList.of(new SplitAssignment(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)), createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(), ImmutableMap.of()); @@ -246,13 +248,13 @@ public void testBufferCloseOnFinish() SqlTask sqlTask = createInitialTask(); OutputBuffers outputBuffers = createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(); - updateTask(sqlTask, EMPTY_SOURCES, outputBuffers); + updateTask(sqlTask, EMPTY_SPLIT_ASSIGNMENTS, outputBuffers); ListenableFuture bufferResult = sqlTask.getTaskResults(OUT, 0, DataSize.of(1, MEGABYTE)); assertFalse(bufferResult.isDone()); // close the sources (no splits will ever be added) - updateTask(sqlTask, ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(), true)), outputBuffers); + updateTask(sqlTask, ImmutableList.of(new SplitAssignment(TABLE_SCAN_NODE_ID, ImmutableSet.of(), true)), outputBuffers); // finish the task by calling abort on it sqlTask.abortTaskResults(OUT); @@ -272,7 +274,7 @@ public void testBufferCloseOnCancel() { SqlTask sqlTask = createInitialTask(); - updateTask(sqlTask, EMPTY_SOURCES, createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds()); + updateTask(sqlTask, EMPTY_SPLIT_ASSIGNMENTS, createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds()); ListenableFuture bufferResult = sqlTask.getTaskResults(OUT, 0, DataSize.of(1, MEGABYTE)); assertFalse(bufferResult.isDone()); @@ -294,7 +296,7 @@ public void testBufferNotCloseOnFail() { SqlTask sqlTask = createInitialTask(); - updateTask(sqlTask, EMPTY_SOURCES, createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds()); + updateTask(sqlTask, EMPTY_SPLIT_ASSIGNMENTS, 
createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds()); ListenableFuture bufferResult = sqlTask.getTaskResults(OUT, 0, DataSize.of(1, MEGABYTE)); assertFalse(bufferResult.isDone()); @@ -317,7 +319,7 @@ public void testDynamicFilters() SqlTask sqlTask = createInitialTask(); sqlTask.updateTask(TEST_SESSION, Optional.of(PLAN_FRAGMENT), - ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), false)), + ImmutableList.of(new SplitAssignment(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), false)), createInitialEmptyOutputBuffers(PARTITIONED) .withBuffer(OUT, 0) .withNoMoreBufferIds(), @@ -364,6 +366,7 @@ private SqlTask createInitialTask() sqlTask -> {}, DataSize.of(32, MEGABYTE), DataSize.of(200, MEGABYTE), + new ExchangeManagerRegistry(new HandleResolver()), new CounterStat()); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java index ea460581b070..7ebc51f03092 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java @@ -183,8 +183,8 @@ public void testSimple(PipelineExecutionStrategy executionStrategy) switch (executionStrategy) { case UNGROUPED_EXECUTION: - // add source for pipeline - sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + // add assignment for pipeline + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( TABLE_SCAN_NODE_ID, ImmutableSet.of(newScheduledSplit(0, TABLE_SCAN_NODE_ID, Lifespan.taskWide(), 100000, 123)), false))); @@ -195,8 +195,8 @@ public void testSimple(PipelineExecutionStrategy executionStrategy) // * operatorFactory will be closed even though operator can't execute // * completedDriverGroups will NOT include the newly scheduled driver group while pause is in place testingScanOperatorFactory.getPauser().pause(); - // add source 
for pipeline, mark as no more splits - sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + // add assignment for pipeline, mark as no more splits + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( TABLE_SCAN_NODE_ID, ImmutableSet.of( newScheduledSplit(1, TABLE_SCAN_NODE_ID, Lifespan.taskWide(), 200000, 300), @@ -214,8 +214,8 @@ public void testSimple(PipelineExecutionStrategy executionStrategy) break; case GROUPED_EXECUTION: - // add source for pipeline (driver group [1, 5]), mark driver group [1] as noMoreSplits - sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + // add assignment for pipeline (driver group [1, 5]), mark driver group [1] as noMoreSplits + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( TABLE_SCAN_NODE_ID, ImmutableSet.of( newScheduledSplit(0, TABLE_SCAN_NODE_ID, Lifespan.driverGroup(1), 0, 1), @@ -233,8 +233,8 @@ public void testSimple(PipelineExecutionStrategy executionStrategy) // * operatorFactory will be closed even though operator can't execute // * completedDriverGroups will NOT include the newly scheduled driver group while pause is in place testingScanOperatorFactory.getPauser().pause(); - // add source for pipeline (driver group [5]), mark driver group [5] as noMoreSplits - sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + // add assignment for pipeline (driver group [5]), mark driver group [5] as noMoreSplits + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( TABLE_SCAN_NODE_ID, ImmutableSet.of(newScheduledSplit(2, TABLE_SCAN_NODE_ID, Lifespan.driverGroup(5), 200000, 300)), ImmutableSet.of(Lifespan.driverGroup(5)), @@ -252,8 +252,8 @@ public void testSimple(PipelineExecutionStrategy executionStrategy) // pause operator execution to make sure that testingScanOperatorFactory.getPauser().pause(); - // add source for pipeline (driver group [7]), mark pipeline as noMoreSplits without explicitly marking driver group 7 - 
sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + // add assignment for pipeline (driver group [7]), mark pipeline as noMoreSplits without explicitly marking driver group 7 + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( TABLE_SCAN_NODE_ID, ImmutableSet.of( newScheduledSplit(3, TABLE_SCAN_NODE_ID, Lifespan.driverGroup(7), 300000, 45), @@ -341,12 +341,12 @@ public void testComplex(PipelineExecutionStrategy executionStrategy) // The following behaviors are tested: // * DriverFactory are marked as noMoreDriver/Operator for particular lifespans as soon as they can be: // * immediately, if the pipeline has task lifecycle (ungrouped and unpartitioned). - // * when TaskSource containing the lifespan is encountered, if the pipeline has driver group lifecycle (grouped and unpartitioned). - // * when TaskSource indicate that no more splits will be produced for the plan node (and plan nodes that schedule before it + // * when SplitAssignment containing the lifespan is encountered, if the pipeline has driver group lifecycle (grouped and unpartitioned). + // * when SplitAssignment indicate that no more splits will be produced for the plan node (and plan nodes that schedule before it // due to phased scheduling) and lifespan combination, if the pipeline has split lifecycle (partitioned). // * DriverFactory are marked as noMoreDriver/Operator as soon as they can be: // * immediately, if the pipeline has task lifecycle (ungrouped and unpartitioned). - // * when TaskSource indicate that will no more splits, otherwise. + // * when SplitAssignment indicate that will no more splits, otherwise. // * Driver groups are marked as completed as soon as they should be: // * when there are no active driver, and all DriverFactory for the lifespan (across all pipelines) are marked as completed. 
// * Rows are produced as soon as they should be: @@ -439,14 +439,14 @@ public void testComplex(PipelineExecutionStrategy executionStrategy) waitUntilEquals(buildOperatorFactoryA::isOverallNoMoreOperators, true, ASSERT_WAIT_TIMEOUT); waitUntilEquals(buildOperatorFactoryC::isOverallNoMoreOperators, true, ASSERT_WAIT_TIMEOUT); - // add source for pipeline 2, and mark as no more splits - sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + // add assignment for pipeline 2, and mark as no more splits + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( scan2NodeId, ImmutableSet.of( newScheduledSplit(0, scan2NodeId, Lifespan.taskWide(), 100000, 1), newScheduledSplit(1, scan2NodeId, Lifespan.taskWide(), 300000, 2)), false))); - sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( scan2NodeId, ImmutableSet.of(newScheduledSplit(2, scan2NodeId, Lifespan.taskWide(), 300000, 2)), true))); @@ -459,8 +459,8 @@ public void testComplex(PipelineExecutionStrategy executionStrategy) // * completedDriverGroups will NOT include the newly scheduled driver group while pause is in place scanOperatorFactory0.getPauser().pause(); - // add source for pipeline 0, mark as no more splits - sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + // add assignment for pipeline 0, mark as no more splits + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( scan0NodeId, ImmutableSet.of(newScheduledSplit(3, scan0NodeId, Lifespan.taskWide(), 400000, 100)), true))); @@ -482,8 +482,8 @@ public void testComplex(PipelineExecutionStrategy executionStrategy) // (Unpartitioned ungrouped pipelines can have all driver instances created up front.) 
waitUntilEquals(buildOperatorFactoryC::isOverallNoMoreOperators, true, ASSERT_WAIT_TIMEOUT); - // add source for pipeline 2 driver group 3, and mark driver group 3 as no more splits - sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + // add assignment for pipeline 2 driver group 3, and mark driver group 3 as no more splits + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( scan2NodeId, ImmutableSet.of( newScheduledSplit(0, scan2NodeId, Lifespan.driverGroup(3), 0, 1), @@ -492,7 +492,7 @@ public void testComplex(PipelineExecutionStrategy executionStrategy) // assert that pipeline 1 driver group [3] will have no more drivers waitUntilEquals(joinOperatorFactoryB::getDriverGroupsWithNoMoreOperators, ImmutableSet.of(Lifespan.driverGroup(3)), ASSERT_WAIT_TIMEOUT); waitUntilEquals(buildOperatorFactoryA::getDriverGroupsWithNoMoreOperators, ImmutableSet.of(Lifespan.driverGroup(3)), ASSERT_WAIT_TIMEOUT); - sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( scan2NodeId, ImmutableSet.of(newScheduledSplit(2, scan2NodeId, Lifespan.driverGroup(3), 200000, 2)), ImmutableSet.of(Lifespan.driverGroup(3)), @@ -505,8 +505,8 @@ public void testComplex(PipelineExecutionStrategy executionStrategy) // * completedDriverGroups will NOT include the newly scheduled driver group while pause is in place scanOperatorFactory0.getPauser().pause(); - // add source for pipeline 0 driver group 3, and mark driver group 3 as no more splits - sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + // add assignment for pipeline 0 driver group 3, and mark driver group 3 as no more splits + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( scan0NodeId, ImmutableSet.of(newScheduledSplit(3, scan0NodeId, Lifespan.driverGroup(3), 300000, 10)), ImmutableSet.of(Lifespan.driverGroup(3)), @@ -524,8 +524,8 @@ public void testComplex(PipelineExecutionStrategy 
executionStrategy) // assert that driver group [3] is fully completed waitUntilEquals(taskContext::getCompletedDriverGroups, ImmutableSet.of(Lifespan.driverGroup(3)), ASSERT_WAIT_TIMEOUT); - // add source for pipeline 2 driver group 7, and mark pipeline as no more splits - sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + // add assignment for pipeline 2 driver group 7, and mark pipeline as no more splits + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( scan2NodeId, ImmutableSet.of(newScheduledSplit(4, scan2NodeId, Lifespan.driverGroup(7), 400000, 2)), ImmutableSet.of(Lifespan.driverGroup(7)), @@ -539,8 +539,8 @@ public void testComplex(PipelineExecutionStrategy executionStrategy) // * completedDriverGroups will NOT include the newly scheduled driver group while pause is in place scanOperatorFactory0.getPauser().pause(); - // add source for pipeline 0 driver group 7, mark pipeline as no more splits - sqlTaskExecution.addSources(ImmutableList.of(new TaskSource( + // add assignment for pipeline 0 driver group 7, mark pipeline as no more splits + sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment( scan0NodeId, ImmutableSet.of(newScheduledSplit(5, scan0NodeId, Lifespan.driverGroup(7), 500000, 1000)), ImmutableSet.of(Lifespan.driverGroup(7)), diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java index 05b684583ae5..5aec35b28558 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java @@ -22,6 +22,7 @@ import io.airlift.units.DataSize; import io.airlift.units.DataSize.Unit; import io.airlift.units.Duration; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.buffer.BufferResult; import io.trino.execution.buffer.BufferState; import io.trino.execution.buffer.OutputBuffers; 
@@ -31,9 +32,10 @@ import io.trino.memory.NodeMemoryConfig; import io.trino.memory.QueryContext; import io.trino.memory.context.LocalMemoryContext; +import io.trino.metadata.HandleResolver; import io.trino.metadata.InternalNode; -import io.trino.operator.ExchangeClient; -import io.trino.operator.ExchangeClientSupplier; +import io.trino.operator.DirectExchangeClient; +import io.trino.operator.DirectExchangeClientSupplier; import io.trino.operator.RetryPolicy; import io.trino.spi.QueryId; import io.trino.spiller.LocalSpillManager; @@ -266,7 +268,7 @@ public void testSessionPropertyMemoryLimitOverride() .build(), reduceLimitsId, Optional.of(PLAN_FRAGMENT), - ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)), + ImmutableList.of(new SplitAssignment(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)), createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(), ImmutableMap.of()); assertTrue(reducesLimitsContext.isMemoryLimitsInitialized()); @@ -281,7 +283,7 @@ public void testSessionPropertyMemoryLimitOverride() .build(), increaseLimitsId, Optional.of(PLAN_FRAGMENT), - ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)), + ImmutableList.of(new SplitAssignment(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)), createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(), ImmutableMap.of()); assertTrue(attemptsIncreaseContext.isMemoryLimitsInitialized()); @@ -310,7 +312,8 @@ private SqlTaskManager createSqlTaskManager(TaskManagerConfig taskManagerConfig, nodeMemoryConfig, localSpillManager, new NodeSpillConfig(), - new TestingGcMonitor()); + new TestingGcMonitor(), + new ExchangeManagerRegistry(new HandleResolver())); } private TaskInfo createTask(SqlTaskManager sqlTaskManager, TaskId taskId, ImmutableSet splits, OutputBuffers outputBuffers) @@ -318,7 +321,7 @@ private TaskInfo createTask(SqlTaskManager sqlTaskManager, TaskId taskId, Immuta return 
sqlTaskManager.updateTask(TEST_SESSION, taskId, Optional.of(PLAN_FRAGMENT), - ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, splits, true)), + ImmutableList.of(new SplitAssignment(TABLE_SCAN_NODE_ID, splits, true)), outputBuffers, ImmutableMap.of()); } @@ -335,11 +338,11 @@ private TaskInfo createTask(SqlTaskManager sqlTaskManager, TaskId taskId, Output ImmutableMap.of()); } - public static class MockExchangeClientSupplier - implements ExchangeClientSupplier + public static class MockDirectExchangeClientSupplier + implements DirectExchangeClientSupplier { @Override - public ExchangeClient get(LocalMemoryContext systemMemoryContext, TaskFailureListener taskFailureListener, RetryPolicy retryPolicy) + public DirectExchangeClient get(LocalMemoryContext systemMemoryContext, TaskFailureListener taskFailureListener, RetryPolicy retryPolicy) { throw new UnsupportedOperationException(); } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java new file mode 100644 index 000000000000..31e2fe8696a4 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java @@ -0,0 +1,299 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multimap; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.units.DataSize; +import io.airlift.units.Duration; +import io.trino.Session; +import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker; +import io.trino.execution.StateMachine.StateChangeListener; +import io.trino.execution.buffer.BufferState; +import io.trino.execution.buffer.OutputBufferInfo; +import io.trino.execution.buffer.OutputBuffers; +import io.trino.metadata.InternalNode; +import io.trino.metadata.Split; +import io.trino.operator.TaskStats; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.plan.DynamicFilterId; +import io.trino.sql.planner.plan.PlanNodeId; +import org.joda.time.DateTime; + +import javax.annotation.concurrent.GuardedBy; + +import java.net.URI; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.Sets.newConcurrentHashSet; +import static com.google.common.util.concurrent.Futures.immediateVoidFuture; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.units.DataSize.Unit.BYTE; +import static io.trino.execution.DynamicFiltersCollector.INITIAL_DYNAMIC_FILTERS_VERSION; +import static io.trino.util.Failures.toFailures; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +public class TestingRemoteTaskFactory + implements RemoteTaskFactory +{ + private 
static final String TASK_INSTANCE_ID = "task-instance-id"; + + private final Map tasks = new HashMap<>(); + + @Override + public synchronized RemoteTask createRemoteTask( + Session session, + TaskId taskId, + InternalNode node, + PlanFragment fragment, + Multimap initialSplits, + OutputBuffers outputBuffers, + PartitionedSplitCountTracker partitionedSplitCountTracker, + Set outboundDynamicFilterIds, + boolean summarizeTaskInfo) + { + TestingRemoteTask task = new TestingRemoteTask(taskId, node.getNodeIdentifier(), fragment); + task.addSplits(initialSplits); + task.setOutputBuffers(outputBuffers); + checkState(tasks.put(taskId, task) == null, "task already exist: %s", taskId); + return task; + } + + public synchronized Map getTasks() + { + return ImmutableMap.copyOf(tasks); + } + + public static class TestingRemoteTask + implements RemoteTask + { + private final TaskStateMachine taskStateMachine; + private final String nodeId; + private final URI location; + private final PlanFragment fragment; + + private final AtomicLong nextTaskStatusVersion = new AtomicLong(TaskStatus.STARTING_VERSION); + + private final AtomicBoolean started = new AtomicBoolean(); + private final Set noMoreSplits = newConcurrentHashSet(); + @GuardedBy("this") + private final Multimap splits = ArrayListMultimap.create(); + @GuardedBy("this") + private OutputBuffers outputBuffers; + + public TestingRemoteTask(TaskId taskId, String nodeId, PlanFragment fragment) + { + this.taskStateMachine = new TaskStateMachine(taskId, directExecutor()); + this.nodeId = requireNonNull(nodeId, "nodeId is null"); + this.location = URI.create("fake://task/" + taskId + "/node/" + nodeId); + this.fragment = requireNonNull(fragment, "fragment is null"); + } + + public PlanFragment getFragment() + { + return fragment; + } + + @Override + public TaskId getTaskId() + { + return taskStateMachine.getTaskId(); + } + + @Override + public String getNodeId() + { + return nodeId; + } + + @Override + public TaskInfo getTaskInfo() 
+ { + return new TaskInfo( + getTaskStatus(), + DateTime.now(), + new OutputBufferInfo( + "TESTING", + BufferState.FINISHED, + false, + false, + 0, + 0, + 0, + 0, + ImmutableList.of()), + ImmutableSet.copyOf(noMoreSplits), + new TaskStats(DateTime.now(), null), + false); + } + + @Override + public TaskStatus getTaskStatus() + { + TaskState state = taskStateMachine.getState(); + List failures = ImmutableList.of(); + if (state == TaskState.FAILED) { + failures = toFailures(taskStateMachine.getFailureCauses()); + } + return new TaskStatus( + taskStateMachine.getTaskId(), + TASK_INSTANCE_ID, + nextTaskStatusVersion.getAndIncrement(), + state, + location, + nodeId, + ImmutableSet.of(), + failures, + 0, + 0, + false, + DataSize.of(0, BYTE), + DataSize.of(0, BYTE), + DataSize.of(0, BYTE), + DataSize.of(0, BYTE), + 0, + new Duration(0, MILLISECONDS), + INITIAL_DYNAMIC_FILTERS_VERSION, + 0, + 0); + } + + @Override + public void start() + { + started.set(true); + } + + public boolean isStarted() + { + return started.get(); + } + + @Override + public synchronized void addSplits(Multimap splits) + { + this.splits.putAll(splits); + } + + public synchronized Multimap getSplits() + { + return ImmutableListMultimap.copyOf(splits); + } + + @Override + public void noMoreSplits(PlanNodeId sourceId) + { + noMoreSplits.add(sourceId); + } + + public Set getNoMoreSplits() + { + return ImmutableSet.copyOf(noMoreSplits); + } + + @Override + public void noMoreSplits(PlanNodeId sourceId, Lifespan lifespan) + { + } + + @Override + public synchronized void setOutputBuffers(OutputBuffers outputBuffers) + { + this.outputBuffers = outputBuffers; + } + + public synchronized OutputBuffers getOutputBuffers() + { + return outputBuffers; + } + + @Override + public void addStateChangeListener(StateChangeListener stateChangeListener) + { + taskStateMachine.addStateChangeListener(newValue -> stateChangeListener.stateChanged(getTaskStatus())); + } + + @Override + public void 
addFinalTaskInfoListener(StateChangeListener stateChangeListener) + { + AtomicBoolean done = new AtomicBoolean(); + StateChangeListener fireOnceStateChangeListener = state -> { + if (state.isDone() && done.compareAndSet(false, true)) { + stateChangeListener.stateChanged(getTaskInfo()); + } + }; + taskStateMachine.addStateChangeListener(fireOnceStateChangeListener); + fireOnceStateChangeListener.stateChanged(taskStateMachine.getState()); + } + + @Override + public ListenableFuture whenSplitQueueHasSpace(long weightThreshold) + { + return immediateVoidFuture(); + } + + @Override + public void cancel() + { + taskStateMachine.cancel(); + } + + @Override + public void abort() + { + taskStateMachine.abort(); + } + + @Override + public PartitionedSplitsInfo getPartitionedSplitsInfo() + { + return PartitionedSplitsInfo.forZeroSplits(); + } + + @Override + public void fail(Throwable cause) + { + taskStateMachine.failed(cause); + } + + @Override + public PartitionedSplitsInfo getQueuedPartitionedSplitsInfo() + { + return PartitionedSplitsInfo.forZeroSplits(); + } + + public void finish() + { + taskStateMachine.finished(); + } + + @Override + public int getUnacknowledgedPartitionedSplitCount() + { + return 0; + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java new file mode 100644 index 000000000000..96f8179befe5 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java @@ -0,0 +1,564 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.ListenableFuture; +import io.trino.Session; +import io.trino.client.NodeVersion; +import io.trino.connector.CatalogName; +import io.trino.cost.StatsAndCosts; +import io.trino.execution.Lifespan; +import io.trino.execution.NodeTaskMap; +import io.trino.execution.RemoteTaskFactory; +import io.trino.execution.SqlStage; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; +import io.trino.execution.TaskState; +import io.trino.execution.TestingRemoteTaskFactory; +import io.trino.execution.TestingRemoteTaskFactory.TestingRemoteTask; +import io.trino.execution.scheduler.TestingExchange.TestingExchangeSinkHandle; +import io.trino.execution.scheduler.TestingExchange.TestingExchangeSourceHandle; +import io.trino.execution.scheduler.TestingNodeSelectorFactory.TestingNodeSupplier; +import io.trino.failuredetector.NoOpFailureDetector; +import io.trino.metadata.InternalNode; +import io.trino.metadata.Split; +import io.trino.spi.QueryId; +import io.trino.spi.exchange.Exchange; +import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.planner.Partitioning; +import io.trino.sql.planner.PartitioningScheme; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNodeId; +import 
io.trino.sql.planner.plan.RemoteSourceNode; +import io.trino.sql.planner.plan.TableScanNode; +import io.trino.testing.TestingMetadata.TestingColumnHandle; +import io.trino.util.FinalizerService; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.collect.Iterables.cycle; +import static com.google.common.collect.Iterables.limit; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.trino.operator.RetryPolicy.TASK; +import static io.trino.operator.StageExecutionDescriptor.ungroupedExecution; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; +import static io.trino.sql.planner.plan.JoinNode.DistributionType.REPLICATED; +import static io.trino.sql.planner.plan.JoinNode.Type.INNER; +import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.testing.TestingSplit.createRemoteSplit; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestFaultTolerantStageScheduler +{ + private static final Session SESSION = testSessionBuilder().build(); + + private static final StageId STAGE_ID = new StageId(new QueryId("query"), 0); + private static final PlanFragmentId FRAGMENT_ID = new PlanFragmentId("0"); + private static final PlanFragmentId 
SOURCE_FRAGMENT_ID_1 = new PlanFragmentId("1"); + private static final PlanFragmentId SOURCE_FRAGMENT_ID_2 = new PlanFragmentId("2"); + private static final PlanNodeId TABLE_SCAN_NODE_ID = new PlanNodeId("table_scan_id"); + + private static final CatalogName CATALOG = new CatalogName("catalog"); + + private static final InternalNode NODE_1 = new InternalNode("node-1", URI.create("local://127.0.0.1:8080"), NodeVersion.UNKNOWN, false); + private static final InternalNode NODE_2 = new InternalNode("node-2", URI.create("local://127.0.0.1:8081"), NodeVersion.UNKNOWN, false); + private static final InternalNode NODE_3 = new InternalNode("node-3", URI.create("local://127.0.0.1:8082"), NodeVersion.UNKNOWN, false); + + private FinalizerService finalizerService; + private NodeTaskMap nodeTaskMap; + + @BeforeClass + public void beforeClass() + { + finalizerService = new FinalizerService(); + finalizerService.start(); + nodeTaskMap = new NodeTaskMap(finalizerService); + } + + @AfterClass(alwaysRun = true) + public void afterClass() + { + nodeTaskMap = null; + if (finalizerService != null) { + finalizerService.destroy(); + finalizerService = null; + } + } + + @Test + public void testHappyPath() + throws Exception + { + TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory(); + TestingTaskSourceFactory taskSourceFactory = createTaskSourceFactory(5, 2); + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( + NODE_1, ImmutableList.of(CATALOG), + NODE_2, ImmutableList.of(CATALOG), + NODE_3, ImmutableList.of(CATALOG))); + + TestingExchange sinkExchange = new TestingExchange(false); + + TestingExchange sourceExchange1 = new TestingExchange(false); + TestingExchange sourceExchange2 = new TestingExchange(false); + + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + createNodeAllocator(nodeSupplier), + TaskLifecycleListener.NO_OP, + Optional.of(sinkExchange), + 
ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 2); + + ListenableFuture blocked = scheduler.isBlocked(); + assertTrue(blocked.isDone()); + + scheduler.schedule(); + + blocked = scheduler.isBlocked(); + // blocked on first source exchange + assertFalse(blocked.isDone()); + + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + // still blocked on the second source exchange + assertFalse(blocked.isDone()); + assertFalse(scheduler.isBlocked().isDone()); + + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + // now unblocked + assertTrue(blocked.isDone()); + assertTrue(scheduler.isBlocked().isDone()); + + scheduler.schedule(); + + blocked = scheduler.isBlocked(); + // blocked on node allocation + assertFalse(blocked.isDone()); + + // not all tasks have been enumerated yet + assertFalse(sinkExchange.isNoMoreSinks()); + + Map tasks = remoteTaskFactory.getTasks(); + // one task per node + assertThat(tasks).hasSize(3); + assertThat(tasks).containsKey(getTaskId(0, 0)); + assertThat(tasks).containsKey(getTaskId(1, 0)); + assertThat(tasks).containsKey(getTaskId(2, 0)); + + TestingRemoteTask task = tasks.get(getTaskId(0, 0)); + // fail task for partition 0 + task.fail(new RuntimeException("some failure")); + + assertTrue(blocked.isDone()); + assertTrue(scheduler.isBlocked().isDone()); + + // schedule more tasks + scheduler.schedule(); + + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).hasSize(4); + assertThat(tasks).containsKey(getTaskId(3, 0)); + + blocked = scheduler.isBlocked(); + // blocked on task scheduling + assertFalse(blocked.isDone()); + + // finish some task + assertThat(tasks).containsKey(getTaskId(1, 0)); + tasks.get(getTaskId(1, 0)).finish(); + + assertTrue(blocked.isDone()); + assertTrue(scheduler.isBlocked().isDone()); + assertThat(sinkExchange.getFinishedSinkHandles()).contains(new TestingExchangeSinkHandle(1)); + + // 
this will schedule failed task + scheduler.schedule(); + + blocked = scheduler.isBlocked(); + // blocked on task scheduling + assertFalse(blocked.isDone()); + + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).hasSize(5); + assertThat(tasks).containsKey(getTaskId(0, 1)); + + // finish some task + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).containsKey(getTaskId(3, 0)); + tasks.get(getTaskId(3, 0)).finish(); + assertThat(sinkExchange.getFinishedSinkHandles()).contains(new TestingExchangeSinkHandle(1), new TestingExchangeSinkHandle(3)); + + assertTrue(blocked.isDone()); + + // schedule the last task + scheduler.schedule(); + + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).hasSize(6); + assertThat(tasks).containsKey(getTaskId(4, 0)); + + // not finished yet, will be finished when all tasks succeed + assertFalse(scheduler.isFinished()); + + blocked = scheduler.isBlocked(); + // blocked on task scheduling + assertFalse(blocked.isDone()); + + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).containsKey(getTaskId(4, 0)); + // finish remaining tasks + tasks.get(getTaskId(0, 1)).finish(); + tasks.get(getTaskId(2, 0)).finish(); + tasks.get(getTaskId(4, 0)).finish(); + + // now it's not blocked and finished + assertTrue(blocked.isDone()); + assertTrue(scheduler.isBlocked().isDone()); + + assertThat(sinkExchange.getFinishedSinkHandles()).contains( + new TestingExchangeSinkHandle(0), + new TestingExchangeSinkHandle(1), + new TestingExchangeSinkHandle(2), + new TestingExchangeSinkHandle(3), + new TestingExchangeSinkHandle(4)); + + assertTrue(scheduler.isFinished()); + } + + @Test + public void testTaskLifecycleListener() + throws Exception + { + TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory(); + TestingTaskSourceFactory taskSourceFactory = createTaskSourceFactory(2, 1); + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( + NODE_1, ImmutableList.of(CATALOG), + NODE_2, 
ImmutableList.of(CATALOG))); + + TestingTaskLifecycleListener taskLifecycleListener = new TestingTaskLifecycleListener(); + + TestingExchange sourceExchange1 = new TestingExchange(false); + TestingExchange sourceExchange2 = new TestingExchange(false); + + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + createNodeAllocator(nodeSupplier), + taskLifecycleListener, + Optional.empty(), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 2); + + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertTrue(scheduler.isBlocked().isDone()); + + scheduler.schedule(); + assertFalse(scheduler.isBlocked().isDone()); + + assertThat(taskLifecycleListener.getTasks().get(FRAGMENT_ID)).contains(getTaskId(0, 0), getTaskId(1, 0)); + + remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some exception")); + + assertTrue(scheduler.isBlocked().isDone()); + scheduler.schedule(); + assertFalse(scheduler.isBlocked().isDone()); + + assertThat(taskLifecycleListener.getTasks().get(FRAGMENT_ID)).contains(getTaskId(0, 0), getTaskId(1, 0), getTaskId(0, 1)); + } + + @Test + public void testTaskFailure() + throws Exception + { + TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory(); + TestingTaskSourceFactory taskSourceFactory = createTaskSourceFactory(3, 1); + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( + NODE_1, ImmutableList.of(CATALOG), + NODE_2, ImmutableList.of(CATALOG))); + + TestingExchange sourceExchange1 = new TestingExchange(false); + TestingExchange sourceExchange2 = new TestingExchange(false); + + NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier); + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, 
+ taskSourceFactory, + nodeAllocator, + TaskLifecycleListener.NO_OP, + Optional.empty(), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 0); + + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertTrue(scheduler.isBlocked().isDone()); + + scheduler.schedule(); + + ListenableFuture blocked = scheduler.isBlocked(); + // waiting on node acquisition + assertFalse(blocked.isDone()); + + ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), Optional.empty())); + ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), Optional.empty())); + + remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some failure")); + + assertTrue(blocked.isDone()); + assertTrue(acquireNode1.isDone()); + assertTrue(acquireNode2.isDone()); + + assertThatThrownBy(scheduler::schedule) + .hasMessageContaining("some failure"); + + assertTrue(scheduler.isBlocked().isDone()); + assertFalse(scheduler.isFinished()); + } + + @Test + public void testReportTaskFailure() + throws Exception + { + TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory(); + TestingTaskSourceFactory taskSourceFactory = createTaskSourceFactory(2, 1); + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( + NODE_1, ImmutableList.of(CATALOG), + NODE_2, ImmutableList.of(CATALOG))); + + TestingExchange sourceExchange1 = new TestingExchange(false); + TestingExchange sourceExchange2 = new TestingExchange(false); + + NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier); + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + TaskLifecycleListener.NO_OP, + Optional.empty(), + 
ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 1); + + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertTrue(scheduler.isBlocked().isDone()); + + scheduler.schedule(); + + ListenableFuture blocked = scheduler.isBlocked(); + // waiting for tasks to finish + assertFalse(blocked.isDone()); + + scheduler.reportTaskFailure(getTaskId(0, 0), new RuntimeException("some failure")); + assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); + + assertTrue(blocked.isDone()); + scheduler.schedule(); + + assertThat(remoteTaskFactory.getTasks()).containsKey(getTaskId(0, 1)); + + remoteTaskFactory.getTasks().get(getTaskId(0, 1)).finish(); + remoteTaskFactory.getTasks().get(getTaskId(1, 0)).finish(); + + assertTrue(scheduler.isBlocked().isDone()); + assertTrue(scheduler.isFinished()); + } + + @Test + public void testCancellation() + throws Exception + { + testCancellation(true); + testCancellation(false); + } + + private void testCancellation(boolean abort) + throws Exception + { + TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory(); + TestingTaskSourceFactory taskSourceFactory = createTaskSourceFactory(3, 1); + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( + NODE_1, ImmutableList.of(CATALOG), + NODE_2, ImmutableList.of(CATALOG))); + + TestingExchange sourceExchange1 = new TestingExchange(false); + TestingExchange sourceExchange2 = new TestingExchange(false); + + NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier); + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + TaskLifecycleListener.NO_OP, + Optional.empty(), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, 
SOURCE_FRAGMENT_ID_2, sourceExchange2), + 0); + + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertTrue(scheduler.isBlocked().isDone()); + + scheduler.schedule(); + + ListenableFuture blocked = scheduler.isBlocked(); + // waiting on node acquisition + assertFalse(blocked.isDone()); + + ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), Optional.empty())); + ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), Optional.empty())); + + if (abort) { + scheduler.abort(); + } + else { + scheduler.cancel(); + } + + assertTrue(blocked.isDone()); + assertTrue(acquireNode1.isDone()); + assertTrue(acquireNode2.isDone()); + + scheduler.schedule(); + + assertTrue(scheduler.isBlocked().isDone()); + assertFalse(scheduler.isFinished()); + } + + private FaultTolerantStageScheduler createFaultTolerantTaskScheduler( + RemoteTaskFactory remoteTaskFactory, + TaskSourceFactory taskSourceFactory, + NodeAllocator nodeAllocator, + TaskLifecycleListener taskLifecycleListener, + Optional sinkExchange, + Map sourceExchanges, + int retryAttempts) + { + return new FaultTolerantStageScheduler( + SESSION, + createSqlStage(remoteTaskFactory), + new NoOpFailureDetector(), + taskSourceFactory, + nodeAllocator, + taskLifecycleListener, + sinkExchange, + Optional.empty(), + sourceExchanges, + Optional.empty(), + Optional.empty(), + retryAttempts); + } + + private SqlStage createSqlStage(RemoteTaskFactory remoteTaskFactory) + { + PlanFragment fragment = createPlanFragment(); + return SqlStage.createSqlStage( + STAGE_ID, + fragment, + ImmutableMap.of(), + remoteTaskFactory, + SESSION, + false, + nodeTaskMap, + directExecutor(), + new SplitSchedulerStats()); + } + + private PlanFragment createPlanFragment() + { + Symbol probeColumnSymbol = new Symbol("probe_column"); + Symbol 
buildColumnSymbol = new Symbol("build_column"); + TableScanNode tableScan = new TableScanNode( + TABLE_SCAN_NODE_ID, + TEST_TABLE_HANDLE, + ImmutableList.of(probeColumnSymbol), + ImmutableMap.of(probeColumnSymbol, new TestingColumnHandle("column")), + TupleDomain.none(), + Optional.empty(), + false, + Optional.empty()); + RemoteSourceNode remoteSource = new RemoteSourceNode( + new PlanNodeId("remote_source_id"), + ImmutableList.of(SOURCE_FRAGMENT_ID_1, SOURCE_FRAGMENT_ID_2), + ImmutableList.of(buildColumnSymbol), + Optional.empty(), + REPLICATE, + TASK); + return new PlanFragment( + FRAGMENT_ID, + new JoinNode( + new PlanNodeId("join_id"), + INNER, + tableScan, + remoteSource, + ImmutableList.of(), + tableScan.getOutputSymbols(), + remoteSource.getOutputSymbols(), + false, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(REPLICATED), + Optional.empty(), + ImmutableMap.of(), + Optional.empty()), + ImmutableMap.of(probeColumnSymbol, VARCHAR, buildColumnSymbol, VARCHAR), + SOURCE_DISTRIBUTION, + ImmutableList.of(TABLE_SCAN_NODE_ID), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(probeColumnSymbol, buildColumnSymbol)), + ungroupedExecution(), + StatsAndCosts.empty(), + Optional.empty()); + } + + private static TestingTaskSourceFactory createTaskSourceFactory(int splitCount, int taskPerBatch) + { + return new TestingTaskSourceFactory(Optional.of(CATALOG), createSplits(splitCount), taskPerBatch); + } + + private static List createSplits(int count) + { + return ImmutableList.copyOf(limit(cycle(new Split(CATALOG, createRemoteSplit(), Lifespan.taskWide())), count)); + } + + private NodeAllocator createNodeAllocator(TestingNodeSupplier nodeSupplier) + { + NodeScheduler nodeScheduler = new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, nodeSupplier)); + return new FixedCountNodeAllocator(nodeScheduler, SESSION, 1); + } + + private static TaskId getTaskId(int partitionId, int attemptId) + { 
+ return new TaskId(STAGE_ID, partitionId, attemptId); + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java new file mode 100644 index 000000000000..6ec7d94b03cd --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java @@ -0,0 +1,362 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.ListenableFuture; +import io.trino.Session; +import io.trino.client.NodeVersion; +import io.trino.connector.CatalogName; +import io.trino.execution.scheduler.TestingNodeSelectorFactory.TestingNodeSupplier; +import io.trino.metadata.InternalNode; +import io.trino.spi.HostAddress; +import org.testng.annotations.Test; + +import java.net.URI; +import java.util.Optional; + +import static io.trino.testing.TestingSession.testSessionBuilder; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestFixedCountNodeAllocator +{ + private static final Session SESSION = testSessionBuilder().build(); + + private static final HostAddress NODE_1_ADDRESS = HostAddress.fromParts("127.0.0.1", 8080); + private static final HostAddress NODE_2_ADDRESS = HostAddress.fromParts("127.0.0.1", 8081); + private static final HostAddress NODE_3_ADDRESS = HostAddress.fromParts("127.0.0.1", 8082); + + private static final InternalNode NODE_1 = new InternalNode("node-1", URI.create("local://" + NODE_1_ADDRESS), NodeVersion.UNKNOWN, false); + private static final InternalNode NODE_2 = new InternalNode("node-2", URI.create("local://" + NODE_2_ADDRESS), NodeVersion.UNKNOWN, false); + private static final InternalNode NODE_3 = new InternalNode("node-3", URI.create("local://" + NODE_3_ADDRESS), NodeVersion.UNKNOWN, false); + + private static final CatalogName CATALOG_1 = new CatalogName("catalog1"); + private static final CatalogName CATALOG_2 = new CatalogName("catalog2"); + + @Test + public void testSingleNode() + throws Exception + { + TestingNodeSupplier 
nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); + + try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertTrue(acquire1.isDone()); + assertEquals(acquire1.get(), NODE_1); + + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire2.isDone()); + + nodeAllocator.release(NODE_1); + + assertTrue(acquire2.isDone()); + assertEquals(acquire2.get(), NODE_1); + } + + try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 2)) { + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertTrue(acquire1.isDone()); + assertEquals(acquire1.get(), NODE_1); + + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertTrue(acquire2.isDone()); + assertEquals(acquire2.get(), NODE_1); + + ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire3.isDone()); + + ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire4.isDone()); + + nodeAllocator.release(NODE_1); + assertTrue(acquire3.isDone()); + assertEquals(acquire3.get(), NODE_1); + + nodeAllocator.release(NODE_1); + assertTrue(acquire4.isDone()); + assertEquals(acquire4.get(), NODE_1); + } + } + + @Test + public void testMultipleNodes() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of(), NODE_2, ImmutableList.of())); + + try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + 
assertTrue(acquire1.isDone()); + assertEquals(acquire1.get(), NODE_1); + + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertTrue(acquire2.isDone()); + assertEquals(acquire2.get(), NODE_2); + + ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire3.isDone()); + + ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire4.isDone()); + + ListenableFuture acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire5.isDone()); + + nodeAllocator.release(NODE_2); + assertTrue(acquire3.isDone()); + assertEquals(acquire3.get(), NODE_2); + + nodeAllocator.release(NODE_1); + assertTrue(acquire4.isDone()); + assertEquals(acquire4.get(), NODE_1); + + nodeAllocator.release(NODE_1); + assertTrue(acquire5.isDone()); + assertEquals(acquire5.get(), NODE_1); + } + + try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 2)) { + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertTrue(acquire1.isDone()); + assertEquals(acquire1.get(), NODE_1); + + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertTrue(acquire2.isDone()); + assertEquals(acquire2.get(), NODE_2); + + ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertTrue(acquire3.isDone()); + assertEquals(acquire3.get(), NODE_1); + + ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertTrue(acquire4.isDone()); + assertEquals(acquire4.get(), NODE_2); + + ListenableFuture acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire5.isDone()); + + 
ListenableFuture acquire6 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire6.isDone()); + + nodeAllocator.release(NODE_2); + assertTrue(acquire5.isDone()); + assertEquals(acquire5.get(), NODE_2); + + nodeAllocator.release(NODE_1); + assertTrue(acquire6.isDone()); + assertEquals(acquire6.get(), NODE_1); + + ListenableFuture acquire7 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire7.isDone()); + + nodeAllocator.release(NODE_1); + assertTrue(acquire7.isDone()); + assertEquals(acquire7.get(), NODE_1); + + nodeAllocator.release(NODE_1); + nodeAllocator.release(NODE_2); + nodeAllocator.release(NODE_2); + + ListenableFuture acquire8 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertTrue(acquire8.isDone()); + assertEquals(acquire8.get(), NODE_2); + } + } + + @Test + public void testCatalogRequirement() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( + NODE_1, ImmutableList.of(CATALOG_1), + NODE_2, ImmutableList.of(CATALOG_2), + NODE_3, ImmutableList.of(CATALOG_1, CATALOG_2))); + + try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + ListenableFuture catalog1acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), Optional.empty())); + assertTrue(catalog1acquire1.isDone()); + assertEquals(catalog1acquire1.get(), NODE_1); + + ListenableFuture catalog1acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), Optional.empty())); + assertTrue(catalog1acquire2.isDone()); + assertEquals(catalog1acquire2.get(), NODE_3); + + ListenableFuture catalog1acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), Optional.empty())); + assertFalse(catalog1acquire3.isDone()); + + ListenableFuture catalog2acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), Optional.empty())); + 
assertTrue(catalog2acquire1.isDone()); + assertEquals(catalog2acquire1.get(), NODE_2); + + ListenableFuture catalog2acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), Optional.empty())); + assertFalse(catalog2acquire2.isDone()); + + nodeAllocator.release(NODE_2); + assertFalse(catalog1acquire3.isDone()); + assertTrue(catalog2acquire2.isDone()); + assertEquals(catalog2acquire2.get(), NODE_2); + + nodeAllocator.release(NODE_1); + assertTrue(catalog1acquire3.isDone()); + assertEquals(catalog1acquire3.get(), NODE_1); + + ListenableFuture catalog1acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), Optional.empty())); + assertFalse(catalog1acquire4.isDone()); + + ListenableFuture catalog2acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), Optional.empty())); + assertFalse(catalog2acquire4.isDone()); + + nodeAllocator.release(NODE_3); + assertFalse(catalog2acquire4.isDone()); + assertTrue(catalog1acquire4.isDone()); + assertEquals(catalog1acquire4.get(), NODE_3); + + nodeAllocator.release(NODE_3); + assertTrue(catalog2acquire4.isDone()); + assertEquals(catalog2acquire4.get(), NODE_3); + } + } + + @Test + public void testCancellation() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); + + try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertTrue(acquire1.isDone()); + assertEquals(acquire1.get(), NODE_1); + + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire2.isDone()); + + ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire3.isDone()); + + acquire2.cancel(true); + + nodeAllocator.release(NODE_1); + 
assertTrue(acquire3.isDone()); + assertEquals(acquire3.get(), NODE_1); + } + } + + @Test + public void testAddNode() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); + + try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertTrue(acquire1.isDone()); + assertEquals(acquire1.get(), NODE_1); + + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire2.isDone()); + + nodeSupplier.addNode(NODE_2, ImmutableList.of()); + nodeAllocator.updateNodes(); + + assertEquals(acquire2.get(10, SECONDS), NODE_2); + } + } + + @Test + public void testRemoveNode() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); + + try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertTrue(acquire1.isDone()); + assertEquals(acquire1.get(), NODE_1); + + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire2.isDone()); + + nodeSupplier.removeNode(NODE_1); + nodeSupplier.addNode(NODE_2, ImmutableList.of()); + nodeAllocator.updateNodes(); + + assertEquals(acquire2.get(10, SECONDS), NODE_2); + + ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.empty())); + assertFalse(acquire3.isDone()); + + nodeAllocator.release(NODE_1); + assertFalse(acquire3.isDone()); + } + } + + @Test + public void testAddressRequirement() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of(), NODE_2, ImmutableList.of())); + 
try (FixedCountNodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.of(ImmutableSet.of(NODE_2_ADDRESS)))); + assertTrue(acquire1.isDone()); + assertEquals(acquire1.get(), NODE_2); + + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.of(ImmutableSet.of(NODE_2_ADDRESS)))); + assertFalse(acquire2.isDone()); + + nodeAllocator.release(NODE_2); + + assertTrue(acquire2.isDone()); + assertEquals(acquire2.get(), NODE_2); + + ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.of(ImmutableSet.of(NODE_3_ADDRESS)))); + assertTrue(acquire3.isDone()); + assertThatThrownBy(acquire3::get) + .hasMessageContaining("No nodes available to run query"); + + nodeSupplier.addNode(NODE_3, ImmutableList.of()); + nodeAllocator.updateNodes(); + + ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.of(ImmutableSet.of(NODE_3_ADDRESS)))); + assertTrue(acquire4.isDone()); + assertEquals(acquire4.get(), NODE_3); + + ListenableFuture acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), Optional.of(ImmutableSet.of(NODE_3_ADDRESS)))); + assertFalse(acquire5.isDone()); + + nodeSupplier.removeNode(NODE_3); + nodeAllocator.updateNodes(); + + assertTrue(acquire5.isDone()); + assertThatThrownBy(acquire5::get) + .hasMessageContaining("No nodes available to run query"); + } + } + + private FixedCountNodeAllocator createNodeAllocator(TestingNodeSupplier testingNodeSupplier, int maximumAllocationsPerNode) + { + return new FixedCountNodeAllocator(createNodeScheduler(testingNodeSupplier), SESSION, maximumAllocationsPerNode); + } + + private NodeScheduler createNodeScheduler(TestingNodeSupplier testingNodeSupplier) + { + return new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, testingNodeSupplier)); + } +} diff --git 
a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java new file mode 100644 index 000000000000..751793761864 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java @@ -0,0 +1,603 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multimap; +import io.airlift.units.DataSize; +import io.trino.connector.CatalogName; +import io.trino.execution.Lifespan; +import io.trino.execution.TableExecuteContextManager; +import io.trino.execution.scheduler.StageTaskSourceFactory.ArbitraryDistributionTaskSource; +import io.trino.execution.scheduler.StageTaskSourceFactory.HashDistributionTaskSource; +import io.trino.execution.scheduler.StageTaskSourceFactory.SingleDistributionTaskSource; +import io.trino.execution.scheduler.StageTaskSourceFactory.SourceDistributionTaskSource; +import io.trino.execution.scheduler.TestingExchange.TestingExchangeSourceHandle; +import io.trino.execution.scheduler.group.DynamicBucketNodeMap; +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.spi.QueryId; 
+import io.trino.spi.connector.ConnectorSplit; +import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeContext; +import io.trino.spi.exchange.ExchangeManager; +import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.split.SplitSource; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNodeId; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static io.airlift.units.DataSize.Unit.BYTE; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestStageTaskSourceFactory +{ + private static final PlanFragmentId FRAGMENT_ID_1 = new PlanFragmentId("1"); + private static final PlanFragmentId FRAGMENT_ID_2 = new PlanFragmentId("2"); + private static final PlanNodeId PLAN_NODE_1 = new PlanNodeId("planNode1"); + private static final PlanNodeId PLAN_NODE_2 = new PlanNodeId("planNode2"); + private static final PlanNodeId PLAN_NODE_3 = new PlanNodeId("planNode3"); + private static final PlanNodeId PLAN_NODE_4 = new PlanNodeId("planNode4"); + private static final PlanNodeId PLAN_NODE_5 = new PlanNodeId("planNode5"); + private static final CatalogName CATALOG = new CatalogName("catalog"); + + @Test + public void testSingleDistributionTaskSource() + { + Multimap sources = ImmutableListMultimap.builder() + .put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 123)) + .put(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 321)) + .put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 222)) + .build(); + TaskSource taskSource = new SingleDistributionTaskSource(sources); + + assertFalse(taskSource.isFinished()); + + List 
tasks = taskSource.getMoreTasks(); + assertThat(tasks).hasSize(1); + assertTrue(taskSource.isFinished()); + + TaskDescriptor task = tasks.get(0); + assertThat(task.getNodeRequirements().getCatalogName()).isEmpty(); + assertThat(task.getNodeRequirements().getAddresses()).isEmpty(); + assertEquals(task.getPartitionId(), 0); + assertEquals(task.getExchangeSourceHandles(), sources); + assertEquals(task.getSplits(), ImmutableListMultimap.of()); + } + + @Test + public void testArbitraryDistributionTaskSource() + { + ExchangeManager splittingExchangeManager = new TestingExchangeManager(true); + ExchangeManager nonSplittingExchangeManager = new TestingExchangeManager(false); + + TaskSource taskSource = new ArbitraryDistributionTaskSource(ImmutableMap.of(), ImmutableMap.of(), ImmutableListMultimap.of(), DataSize.of(3, BYTE)); + assertFalse(taskSource.isFinished()); + List tasks = taskSource.getMoreTasks(); + assertThat(tasks).isEmpty(); + assertTrue(taskSource.isFinished()); + + Multimap sources = ImmutableListMultimap.of(FRAGMENT_ID_1, new TestingExchangeSourceHandle(0, 3)); + Exchange exchange = splittingExchangeManager.create(new ExchangeContext(new QueryId("query"), 0), 3); + taskSource = new ArbitraryDistributionTaskSource( + ImmutableMap.of(FRAGMENT_ID_1, PLAN_NODE_1), + ImmutableMap.of(FRAGMENT_ID_1, exchange), + sources, + DataSize.of(3, BYTE)); + tasks = taskSource.getMoreTasks(); + assertTrue(taskSource.isFinished()); + assertThat(tasks).hasSize(1); + assertEquals(tasks, ImmutableList.of(new TaskDescriptor( + 0, + ImmutableListMultimap.of(), + ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 3)), + new NodeRequirements(Optional.empty(), Optional.empty())))); + + sources = ImmutableListMultimap.of(FRAGMENT_ID_1, new TestingExchangeSourceHandle(0, 123)); + exchange = nonSplittingExchangeManager.create(new ExchangeContext(new QueryId("query"), 0), 3); + taskSource = new ArbitraryDistributionTaskSource( + ImmutableMap.of(FRAGMENT_ID_1, 
PLAN_NODE_1), + ImmutableMap.of(FRAGMENT_ID_1, exchange), + sources, + DataSize.of(3, BYTE)); + tasks = taskSource.getMoreTasks(); + assertEquals(tasks, ImmutableList.of(new TaskDescriptor( + 0, + ImmutableListMultimap.of(), + ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 123)), + new NodeRequirements(Optional.empty(), Optional.empty())))); + + sources = ImmutableListMultimap.of( + FRAGMENT_ID_1, new TestingExchangeSourceHandle(0, 123), + FRAGMENT_ID_2, new TestingExchangeSourceHandle(0, 321)); + exchange = nonSplittingExchangeManager.create(new ExchangeContext(new QueryId("query"), 0), 3); + taskSource = new ArbitraryDistributionTaskSource( + ImmutableMap.of(FRAGMENT_ID_1, PLAN_NODE_1, FRAGMENT_ID_2, PLAN_NODE_2), + ImmutableMap.of(FRAGMENT_ID_1, exchange, FRAGMENT_ID_2, exchange), + sources, + DataSize.of(3, BYTE)); + tasks = taskSource.getMoreTasks(); + assertEquals(tasks, ImmutableList.of( + new TaskDescriptor( + 0, + ImmutableListMultimap.of(), + ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 123)), + new NodeRequirements(Optional.empty(), Optional.empty())), + new TaskDescriptor( + 1, + ImmutableListMultimap.of(), + ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 321)), + new NodeRequirements(Optional.empty(), Optional.empty())))); + + sources = ImmutableListMultimap.of( + FRAGMENT_ID_1, new TestingExchangeSourceHandle(0, 1), + FRAGMENT_ID_1, new TestingExchangeSourceHandle(0, 2), + FRAGMENT_ID_2, new TestingExchangeSourceHandle(0, 4)); + exchange = splittingExchangeManager.create(new ExchangeContext(new QueryId("query"), 0), 3); + taskSource = new ArbitraryDistributionTaskSource( + ImmutableMap.of(FRAGMENT_ID_1, PLAN_NODE_1, FRAGMENT_ID_2, PLAN_NODE_2), + ImmutableMap.of(FRAGMENT_ID_1, exchange, FRAGMENT_ID_2, exchange), + sources, + DataSize.of(3, BYTE)); + tasks = taskSource.getMoreTasks(); + assertEquals(tasks, ImmutableList.of( + new TaskDescriptor( + 0, + 
ImmutableListMultimap.of(), + ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), PLAN_NODE_1, new TestingExchangeSourceHandle(0, 2)), + new NodeRequirements(Optional.empty(), Optional.empty())), + new TaskDescriptor( + 1, + ImmutableListMultimap.of(), + ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 3)), + new NodeRequirements(Optional.empty(), Optional.empty())), + new TaskDescriptor( + 2, + ImmutableListMultimap.of(), + ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1)), + new NodeRequirements(Optional.empty(), Optional.empty())))); + + sources = ImmutableListMultimap.of( + FRAGMENT_ID_1, new TestingExchangeSourceHandle(0, 1), + FRAGMENT_ID_1, new TestingExchangeSourceHandle(0, 3), + FRAGMENT_ID_2, new TestingExchangeSourceHandle(0, 4)); + exchange = splittingExchangeManager.create(new ExchangeContext(new QueryId("query"), 0), 3); + taskSource = new ArbitraryDistributionTaskSource( + ImmutableMap.of(FRAGMENT_ID_1, PLAN_NODE_1, FRAGMENT_ID_2, PLAN_NODE_2), + ImmutableMap.of(FRAGMENT_ID_1, exchange, FRAGMENT_ID_2, exchange), + sources, + DataSize.of(3, BYTE)); + tasks = taskSource.getMoreTasks(); + assertEquals(tasks, ImmutableList.of( + new TaskDescriptor( + 0, + ImmutableListMultimap.of(), + ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1)), + new NodeRequirements(Optional.empty(), Optional.empty())), + new TaskDescriptor( + 1, + ImmutableListMultimap.of(), + ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 3)), + new NodeRequirements(Optional.empty(), Optional.empty())), + new TaskDescriptor( + 2, + ImmutableListMultimap.of(), + ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 3)), + new NodeRequirements(Optional.empty(), Optional.empty())), + new TaskDescriptor( + 3, + ImmutableListMultimap.of(), + ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1)), + new 
NodeRequirements(Optional.empty(), Optional.empty())))); + } + + @Test + public void testHashDistributionTaskSource() + { + TaskSource taskSource = createHashDistributionTaskSource( + ImmutableMap.of(), + ImmutableListMultimap.of(), + ImmutableListMultimap.of(), + 1, + Optional.empty(), + Optional.empty()); + assertFalse(taskSource.isFinished()); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of()); + assertTrue(taskSource.isFinished()); + + taskSource = createHashDistributionTaskSource( + ImmutableMap.of(), + ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1)), + ImmutableListMultimap.of( + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), + 1, + Optional.empty(), + Optional.empty()); + assertFalse(taskSource.isFinished()); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + new TaskDescriptor(0, ImmutableListMultimap.of(), ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())), + new TaskDescriptor(1, ImmutableListMultimap.of(), ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())), + new TaskDescriptor(2, ImmutableListMultimap.of(), ImmutableListMultimap.of( + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())))); + assertTrue(taskSource.isFinished()); + + Split bucketedSplit1 = createBucketedSplit(0, 0); + Split bucketedSplit2 = createBucketedSplit(0, 2); + Split bucketedSplit3 = createBucketedSplit(0, 3); 
+ Split bucketedSplit4 = createBucketedSplit(0, 1); + + taskSource = createHashDistributionTaskSource( + ImmutableMap.of( + PLAN_NODE_4, new TestingSplitSource(CATALOG, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), + PLAN_NODE_5, new TestingSplitSource(CATALOG, ImmutableList.of(bucketedSplit4))), + ImmutableListMultimap.of(), + ImmutableListMultimap.of( + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), + 1, + Optional.empty(), + Optional.of(getTestingBucketNodeMap(4))); + assertFalse(taskSource.isFinished()); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + new TaskDescriptor( + 0, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit1), + ImmutableListMultimap.of( + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())), + new TaskDescriptor( + 1, + ImmutableListMultimap.of( + PLAN_NODE_5, bucketedSplit4), + ImmutableListMultimap.of( + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())), + new TaskDescriptor( + 2, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit2), + ImmutableListMultimap.of( + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())), + new TaskDescriptor( + 3, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit3), + ImmutableListMultimap.of( + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())))); + assertTrue(taskSource.isFinished()); + + taskSource = createHashDistributionTaskSource( + ImmutableMap.of( + PLAN_NODE_4, new TestingSplitSource(CATALOG, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), + PLAN_NODE_5, new TestingSplitSource(CATALOG, ImmutableList.of(bucketedSplit4))), + ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), + PLAN_NODE_2, new 
TestingExchangeSourceHandle(0, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1)), + ImmutableListMultimap.of( + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), + 1, + Optional.of(new int[] {0, 1, 2, 3}), + Optional.of(getTestingBucketNodeMap(4))); + assertFalse(taskSource.isFinished()); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + new TaskDescriptor( + 0, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit1), + ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())), + new TaskDescriptor( + 1, + ImmutableListMultimap.of( + PLAN_NODE_5, bucketedSplit4), + ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())), + new TaskDescriptor( + 2, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit2), + ImmutableListMultimap.of( + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())), + new TaskDescriptor( + 3, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit3), + ImmutableListMultimap.of( + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())))); + assertTrue(taskSource.isFinished()); + + taskSource = createHashDistributionTaskSource( + ImmutableMap.of( + PLAN_NODE_4, new TestingSplitSource(CATALOG, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), + PLAN_NODE_5, new TestingSplitSource(CATALOG, ImmutableList.of(bucketedSplit4))), + ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), + PLAN_NODE_2, new 
TestingExchangeSourceHandle(0, 1)), + ImmutableListMultimap.of( + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), + 2, + Optional.of(new int[] {0, 1, 0, 1}), + Optional.of(getTestingBucketNodeMap(4))); + assertFalse(taskSource.isFinished()); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + new TaskDescriptor( + 0, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit1, + PLAN_NODE_4, bucketedSplit2), + ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())), + new TaskDescriptor( + 1, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit3, + PLAN_NODE_5, bucketedSplit4), + ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), Optional.empty())))); + assertTrue(taskSource.isFinished()); + } + + private static HashDistributionTaskSource createHashDistributionTaskSource( + Map splitSources, + Multimap partitionedExchangeSources, + Multimap replicatedExchangeSources, + int splitBatchSize, + Optional bucketToPartitionMap, + Optional bucketNodeMap) + { + return new HashDistributionTaskSource( + splitSources, + partitionedExchangeSources, + replicatedExchangeSources, + splitBatchSize, + (getSplitsTime) -> {}, + bucketToPartitionMap, + bucketNodeMap, + Optional.of(CATALOG)); + } + + @Test + public void testSourceDistributionTaskSource() + { + TaskSource taskSource = createSourceDistributionTaskSource(ImmutableList.of(), ImmutableListMultimap.of(), 2, 3); + assertFalse(taskSource.isFinished()); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of()); + assertTrue(taskSource.isFinished()); + + Split split1 = createSplit(1); + Split split2 = createSplit(2); + Split split3 = createSplit(3); + + taskSource = 
createSourceDistributionTaskSource( + ImmutableList.of(split1), + ImmutableListMultimap.of(), + 2, + 2); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of(new TaskDescriptor( + 0, + ImmutableListMultimap.of(PLAN_NODE_1, split1), + ImmutableListMultimap.of(), + new NodeRequirements(Optional.of(CATALOG), Optional.empty())))); + assertTrue(taskSource.isFinished()); + + taskSource = createSourceDistributionTaskSource( + ImmutableList.of(split1, split2, split3), + ImmutableListMultimap.of(), + 3, + 2); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + new TaskDescriptor(0, ImmutableListMultimap.of(PLAN_NODE_1, split1, PLAN_NODE_1, split2), ImmutableListMultimap.of(), new NodeRequirements(Optional.of(CATALOG), Optional.empty())))); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + new TaskDescriptor(1, ImmutableListMultimap.of(PLAN_NODE_1, split3), ImmutableListMultimap.of(), new NodeRequirements(Optional.of(CATALOG), Optional.empty())))); + assertTrue(taskSource.isFinished()); + + ImmutableListMultimap replicatedSources = ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1)); + taskSource = createSourceDistributionTaskSource( + ImmutableList.of(split1, split2, split3), + replicatedSources, + 2, + 2); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + new TaskDescriptor(0, ImmutableListMultimap.of(PLAN_NODE_1, split1, PLAN_NODE_1, split2), replicatedSources, new NodeRequirements(Optional.of(CATALOG), Optional.empty())))); + assertFalse(taskSource.isFinished()); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + new TaskDescriptor(1, ImmutableListMultimap.of(PLAN_NODE_1, split3), replicatedSources, new NodeRequirements(Optional.of(CATALOG), Optional.empty())))); + assertTrue(taskSource.isFinished()); + + // non remotely accessible splits + ImmutableList splits = ImmutableList.of( + createSplit(1, ImmutableList.of(HostAddress.fromString("host1:8080"), HostAddress.fromString("host2:8080"))), + 
createSplit(2, ImmutableList.of(HostAddress.fromString("host2:8080"))), + createSplit(3, ImmutableList.of(HostAddress.fromString("host1:8080"), HostAddress.fromString("host3:8080"))), + createSplit(4, ImmutableList.of(HostAddress.fromString("host3:8080"), HostAddress.fromString("host1:8080"))), + createSplit(5, ImmutableList.of(HostAddress.fromString("host1:8080"), HostAddress.fromString("host2:8080"))), + createSplit(6, ImmutableList.of(HostAddress.fromString("host2:8080"), HostAddress.fromString("host3:8080"))), + createSplit(7, ImmutableList.of(HostAddress.fromString("host3:8080"), HostAddress.fromString("host4:8080")))); + taskSource = createSourceDistributionTaskSource(splits, ImmutableListMultimap.of(), 3, 2); + + List tasks = taskSource.getMoreTasks(); + assertEquals(tasks.size(), 1); + assertEquals(tasks.get(0).getNodeRequirements().getAddresses(), Optional.of(ImmutableSet.of(HostAddress.fromString("host1:8080")))); + assertThat(tasks.get(0).getSplits().get(PLAN_NODE_1)).containsExactlyInAnyOrder(splits.get(0), splits.get(2)); + assertFalse(taskSource.isFinished()); + + tasks = taskSource.getMoreTasks(); + assertEquals(tasks.size(), 1); + assertEquals(tasks.get(0).getNodeRequirements().getAddresses(), Optional.of(ImmutableSet.of(HostAddress.fromString("host1:8080")))); + assertThat(tasks.get(0).getSplits().get(PLAN_NODE_1)).containsExactlyInAnyOrder(splits.get(3), splits.get(4)); + assertFalse(taskSource.isFinished()); + + tasks = taskSource.getMoreTasks(); + assertEquals(tasks.size(), 1); + assertEquals(tasks.get(0).getNodeRequirements().getAddresses(), Optional.of(ImmutableSet.of(HostAddress.fromString("host2:8080")))); + assertThat(tasks.get(0).getSplits().get(PLAN_NODE_1)).containsExactlyInAnyOrder(splits.get(1), splits.get(5)); + assertFalse(taskSource.isFinished()); + + tasks = taskSource.getMoreTasks(); + assertEquals(tasks.size(), 1); + assertEquals(tasks.get(0).getNodeRequirements().getAddresses(), 
Optional.of(ImmutableSet.of(HostAddress.fromString("host3:8080")))); + assertThat(tasks.get(0).getSplits().get(PLAN_NODE_1)).containsExactlyInAnyOrder(splits.get(6)); + assertTrue(taskSource.isFinished()); + } + + private static SourceDistributionTaskSource createSourceDistributionTaskSource( + List splits, + Multimap replicatedSources, + int splitBatchSize, + int splitsPerTask) + { + return new SourceDistributionTaskSource( + new QueryId("query"), + PLAN_NODE_1, + new TableExecuteContextManager(), + new TestingSplitSource(CATALOG, splits), + replicatedSources, + splitBatchSize, + (getSplitsTime) -> {}, + Optional.of(CATALOG), + splitsPerTask); + } + + private static Split createSplit(int id) + { + return new Split(CATALOG, new TestingConnectorSplit(id, OptionalInt.empty(), Optional.empty()), Lifespan.taskWide()); + } + + private static Split createSplit(int id, List addresses) + { + return new Split(CATALOG, new TestingConnectorSplit(id, OptionalInt.empty(), Optional.of(addresses)), Lifespan.taskWide()); + } + + private static Split createBucketedSplit(int id, int bucket) + { + return createBucketedSplit(id, bucket, Optional.empty()); + } + + private static Split createBucketedSplit(int id, int bucket, Optional> addresses) + { + return new Split(CATALOG, new TestingConnectorSplit(id, OptionalInt.of(bucket), addresses), Lifespan.taskWide()); + } + + private static BucketNodeMap getTestingBucketNodeMap(int bucketCount) + { + return new DynamicBucketNodeMap((split) -> { + TestingConnectorSplit testingConnectorSplit = (TestingConnectorSplit) split.getConnectorSplit(); + return testingConnectorSplit.getBucket().getAsInt(); + }, bucketCount); + } + + private static class TestingConnectorSplit + implements ConnectorSplit + { + private final int id; + private final OptionalInt bucket; + private final Optional> addresses; + + public TestingConnectorSplit(int id, OptionalInt bucket, Optional> addresses) + { + this.id = id; + this.bucket = requireNonNull(bucket, "bucket is 
null"); + this.addresses = requireNonNull(addresses, "addresses is null").map(ImmutableList::copyOf); + } + + public int getId() + { + return id; + } + + public OptionalInt getBucket() + { + return bucket; + } + + @Override + public boolean isRemotelyAccessible() + { + return addresses.isEmpty(); + } + + @Override + public List getAddresses() + { + return addresses.orElse(ImmutableList.of()); + } + + @Override + public Object getInfo() + { + return null; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestingConnectorSplit that = (TestingConnectorSplit) o; + return id == that.id && Objects.equals(bucket, that.bucket) && Objects.equals(addresses, that.addresses); + } + + @Override + public int hashCode() + { + return Objects.hash(id, bucket, addresses); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("id", id) + .add("bucket", bucket) + .add("addresses", addresses) + .toString(); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java new file mode 100644 index 000000000000..bc252f99f73e --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java @@ -0,0 +1,280 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.execution.scheduler;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.spi.exchange.Exchange;
import io.trino.spi.exchange.ExchangeSinkHandle;
import io.trino.spi.exchange.ExchangeSinkInstanceHandle;
import io.trino.spi.exchange.ExchangeSourceHandle;
import io.trino.spi.exchange.ExchangeSourceSplitter;
import io.trino.spi.exchange.ExchangeSourceStatistics;

import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Iterators.cycle;
import static com.google.common.collect.Iterators.limit;
import static com.google.common.collect.Sets.newConcurrentHashSet;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

/**
 * In-memory {@link Exchange} stub for scheduler tests.
 *
 * <p>Records sink registration and completion, lets the test publish source
 * handles explicitly via {@link #setSourceHandles(List)}, and (optionally)
 * simulates splitting of a source partition into chunks of a target size.
 */
public class TestingExchange
        implements Exchange
{
    // When false, split() returns each handle unchanged.
    private final boolean splitPartitionsEnabled;

    private final Set<TestingExchangeSinkHandle> finishedSinks = newConcurrentHashSet();
    // All sink handles ever added; kept for test inspection even though the
    // visible code only writes to it.
    private final Set<TestingExchangeSinkHandle> allSinks = newConcurrentHashSet();
    private final AtomicBoolean noMoreSinks = new AtomicBoolean();
    // Completed externally by setSourceHandles(...)
    private final CompletableFuture<List<ExchangeSourceHandle>> sourceHandles = new CompletableFuture<>();

    public TestingExchange(boolean splitPartitionsEnabled)
    {
        this.splitPartitionsEnabled = splitPartitionsEnabled;
    }

    @Override
    public ExchangeSinkHandle addSink(int taskPartitionId)
    {
        TestingExchangeSinkHandle sinkHandle = new TestingExchangeSinkHandle(taskPartitionId);
        allSinks.add(sinkHandle);
        return sinkHandle;
    }

    @Override
    public void noMoreSinks()
    {
        noMoreSinks.set(true);
    }

    /**
     * @return whether {@link #noMoreSinks()} has been called
     */
    public boolean isNoMoreSinks()
    {
        return noMoreSinks.get();
    }

    @Override
    public ExchangeSinkInstanceHandle instantiateSink(ExchangeSinkHandle sinkHandle, int taskAttemptId)
    {
        return new TestingExchangeSinkInstanceHandle((TestingExchangeSinkHandle) sinkHandle, taskAttemptId);
    }

    @Override
    public void sinkFinished(ExchangeSinkInstanceHandle handle)
    {
        finishedSinks.add(((TestingExchangeSinkInstanceHandle) handle).getSinkHandle());
    }

    /**
     * @return snapshot of sink handles reported finished so far
     */
    public Set<TestingExchangeSinkHandle> getFinishedSinkHandles()
    {
        return ImmutableSet.copyOf(finishedSinks);
    }

    @Override
    public CompletableFuture<List<ExchangeSourceHandle>> getSourceHandles()
    {
        return sourceHandles;
    }

    /**
     * Completes the future returned by {@link #getSourceHandles()} with the given handles.
     */
    public void setSourceHandles(List<ExchangeSourceHandle> handles)
    {
        sourceHandles.complete(ImmutableList.copyOf(handles));
    }

    @Override
    public ExchangeSourceSplitter split(ExchangeSourceHandle handle, long targetSizeInBytes)
    {
        // Splitting is computed eagerly; the returned splitter just drains the list.
        List<ExchangeSourceHandle> splitHandles = splitIntoList(handle, targetSizeInBytes);
        Iterator<ExchangeSourceHandle> iterator = splitHandles.iterator();
        return new ExchangeSourceSplitter()
        {
            @Override
            public CompletableFuture<?> isBlocked()
            {
                // Never blocks: all sub-handles are available immediately.
                return NOT_BLOCKED;
            }

            @Override
            public Optional<ExchangeSourceHandle> getNext()
            {
                if (iterator.hasNext()) {
                    return Optional.of(iterator.next());
                }
                return Optional.empty();
            }

            @Override
            public void close()
            {
            }
        };
    }

    /**
     * Splits a {@link TestingExchangeSourceHandle} into {@code size / targetSizeInBytes}
     * full-size handles plus one handle for the remainder; returns the handle unchanged
     * when splitting is disabled.
     */
    private List<ExchangeSourceHandle> splitIntoList(ExchangeSourceHandle handle, long targetSizeInBytes)
    {
        if (!splitPartitionsEnabled) {
            return ImmutableList.of(handle);
        }
        checkArgument(targetSizeInBytes > 0, "targetSizeInBytes must be positive: %s", targetSizeInBytes);
        TestingExchangeSourceHandle testingExchangeSourceHandle = (TestingExchangeSourceHandle) handle;
        long currentSize = testingExchangeSourceHandle.getSizeInBytes();
        int fullPartitions = toIntExact(currentSize / targetSizeInBytes);
        long remainder = currentSize % targetSizeInBytes;
        ImmutableList.Builder<ExchangeSourceHandle> result = ImmutableList.builder();
        if (fullPartitions > 0) {
            // cycle(...) repeats the same (equal) handle; limit(...) caps it at fullPartitions copies
            result.addAll(limit(cycle(new TestingExchangeSourceHandle(testingExchangeSourceHandle.getPartitionId(), targetSizeInBytes)), fullPartitions));
        }
        if (remainder > 0) {
            result.add(new TestingExchangeSourceHandle(testingExchangeSourceHandle.getPartitionId(), remainder));
        }
        return result.build();
    }

    @Override
    public ExchangeSourceStatistics getExchangeSourceStatistics(ExchangeSourceHandle handle)
    {
        return new ExchangeSourceStatistics(((TestingExchangeSourceHandle) handle).getSizeInBytes());
    }

    @Override
    public void close()
    {
    }

    /**
     * Pairs a sink handle with a task attempt id.
     */
    public static class TestingExchangeSinkInstanceHandle
            implements ExchangeSinkInstanceHandle
    {
        private final TestingExchangeSinkHandle sinkHandle;
        private final int attemptId;

        public TestingExchangeSinkInstanceHandle(TestingExchangeSinkHandle sinkHandle, int attemptId)
        {
            this.sinkHandle = requireNonNull(sinkHandle, "sinkHandle is null");
            this.attemptId = attemptId;
        }

        public TestingExchangeSinkHandle getSinkHandle()
        {
            return sinkHandle;
        }

        public int getAttemptId()
        {
            return attemptId;
        }
    }

    /**
     * Source handle identified by (partitionId, sizeInBytes); value-based equality.
     */
    public static class TestingExchangeSourceHandle
            implements ExchangeSourceHandle
    {
        private final int partitionId;
        private final long sizeInBytes;

        public TestingExchangeSourceHandle(int partitionId, long sizeInBytes)
        {
            this.partitionId = partitionId;
            this.sizeInBytes = sizeInBytes;
        }

        @Override
        public int getPartitionId()
        {
            return partitionId;
        }

        public long getSizeInBytes()
        {
            return sizeInBytes;
        }

        @Override
        public boolean equals(Object o)
        {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            TestingExchangeSourceHandle that = (TestingExchangeSourceHandle) o;
            return partitionId == that.partitionId && sizeInBytes == that.sizeInBytes;
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(partitionId, sizeInBytes);
        }

        @Override
        public String toString()
        {
            return toStringHelper(this)
                    .add("partitionId", partitionId)
                    .add("sizeInBytes", sizeInBytes)
                    .toString();
        }
    }

    /**
     * Sink handle identified by its task partition id; value-based equality.
     */
    public static class TestingExchangeSinkHandle
            implements ExchangeSinkHandle
    {
        private final int taskPartitionId;

        public TestingExchangeSinkHandle(int taskPartitionId)
        {
            this.taskPartitionId = taskPartitionId;
        }

        public int getTaskPartitionId()
        {
            return taskPartitionId;
        }

        @Override
        public boolean equals(Object o)
        {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            TestingExchangeSinkHandle sinkHandle = (TestingExchangeSinkHandle) o;
            return taskPartitionId == sinkHandle.taskPartitionId;
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(taskPartitionId);
        }

        @Override
        public String toString()
        {
            return toStringHelper(this)
                    .add("taskPartitionId", taskPartitionId)
                    .toString();
        }
    }
}
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.execution.scheduler;

import io.trino.spi.exchange.Exchange;
import io.trino.spi.exchange.ExchangeContext;
import io.trino.spi.exchange.ExchangeManager;
import io.trino.spi.exchange.ExchangeSink;
import io.trino.spi.exchange.ExchangeSinkInstanceHandle;
import io.trino.spi.exchange.ExchangeSource;
import io.trino.spi.exchange.ExchangeSourceHandle;

import java.util.List;

/**
 * {@link ExchangeManager} stub that hands out {@link TestingExchange} instances.
 *
 * <p>Only {@link #create} is supported; sink and source creation are not needed
 * by the scheduler tests and throw {@link UnsupportedOperationException}.
 */
public class TestingExchangeManager
        implements ExchangeManager
{
    // Forwarded to every TestingExchange created by this manager.
    private final boolean splitPartitionsEnabled;

    public TestingExchangeManager(boolean splitPartitionsEnabled)
    {
        this.splitPartitionsEnabled = splitPartitionsEnabled;
    }

    @Override
    public Exchange create(ExchangeContext context, int outputPartitionCount)
    {
        // context and outputPartitionCount are intentionally ignored by the stub
        return new TestingExchange(splitPartitionsEnabled);
    }

    @Override
    public ExchangeSink createSink(ExchangeSinkInstanceHandle handle)
    {
        throw new UnsupportedOperationException();
    }

    @Override
    public ExchangeSource createSource(List<ExchangeSourceHandle> handles)
    {
        throw new UnsupportedOperationException();
    }
}
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.execution.scheduler;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.connector.CatalogName;
import io.trino.execution.RemoteTask;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

/**
 * {@link NodeSelectorFactory} stub backed by an externally mutable node
 * supplier, so tests can add and remove nodes while a scheduler is running.
 *
 * <p>The supplier maps each node to the catalogs it can serve; selectors
 * created for a specific catalog only see nodes that list that catalog.
 */
public class TestingNodeSelectorFactory
        implements NodeSelectorFactory
{
    private final InternalNode currentNode;
    private final Supplier<Map<InternalNode, List<CatalogName>>> nodesSupplier;

    public TestingNodeSelectorFactory(InternalNode currentNode, Supplier<Map<InternalNode, List<CatalogName>>> nodesSupplier)
    {
        this.currentNode = requireNonNull(currentNode, "currentNode is null");
        this.nodesSupplier = requireNonNull(nodesSupplier, "nodesSupplier is null");
    }

    @Override
    public NodeSelector createNodeSelector(Session session, Optional<CatalogName> catalogName)
    {
        return new TestingNodeSelector(currentNode, createNodesSupplierForCatalog(catalogName, nodesSupplier));
    }

    /**
     * Adapts the node-to-catalogs supplier into a node-list supplier, filtered
     * to nodes serving the given catalog (or all nodes when no catalog is set).
     */
    private static Supplier<List<InternalNode>> createNodesSupplierForCatalog(
            Optional<CatalogName> catalogNameOptional,
            Supplier<Map<InternalNode, List<CatalogName>>> nodesSupplier)
    {
        return () -> {
            Map<InternalNode, List<CatalogName>> allNodes = nodesSupplier.get();
            if (catalogNameOptional.isEmpty()) {
                return ImmutableList.copyOf(allNodes.keySet());
            }
            CatalogName catalogName = catalogNameOptional.get();
            return allNodes.entrySet().stream()
                    .filter(entry -> entry.getValue().contains(catalogName))
                    .map(Map.Entry::getKey)
                    .collect(toImmutableList());
        };
    }

    /**
     * Mutable, thread-safe node registry for tests.
     */
    public static class TestingNodeSupplier
            implements Supplier<Map<InternalNode, List<CatalogName>>>
    {
        private final Map<InternalNode, List<CatalogName>> nodes = new ConcurrentHashMap<>();

        public static TestingNodeSupplier create()
        {
            return new TestingNodeSupplier();
        }

        public static TestingNodeSupplier create(Map<InternalNode, List<CatalogName>> nodes)
        {
            TestingNodeSupplier testingNodeSupplier = new TestingNodeSupplier();
            nodes.forEach(testingNodeSupplier::addNode);
            return testingNodeSupplier;
        }

        private TestingNodeSupplier() {}

        public void addNode(InternalNode node, List<CatalogName> catalogs)
        {
            nodes.put(node, catalogs);
        }

        public void removeNode(InternalNode node)
        {
            nodes.remove(node);
        }

        @Override
        public Map<InternalNode, List<CatalogName>> get()
        {
            // Live view: callers observe subsequent addNode/removeNode calls.
            return nodes;
        }
    }

    /**
     * Minimal {@link NodeSelector}: supports listing and random selection only;
     * split placement operations are unsupported.
     */
    private static class TestingNodeSelector
            implements NodeSelector
    {
        private final InternalNode currentNode;
        private final Supplier<List<InternalNode>> nodesSupplier;

        private TestingNodeSelector(InternalNode currentNode, Supplier<List<InternalNode>> nodesSupplier)
        {
            this.currentNode = requireNonNull(currentNode, "currentNode is null");
            this.nodesSupplier = requireNonNull(nodesSupplier, "nodesSupplier is null");
        }

        @Override
        public void lockDownNodes()
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public List<InternalNode> allNodes()
        {
            return ImmutableList.copyOf(nodesSupplier.get());
        }

        @Override
        public InternalNode selectCurrentNode()
        {
            return currentNode;
        }

        @Override
        public List<InternalNode> selectRandomNodes(int limit, Set<InternalNode> excludedNodes)
        {
            // Not actually random: returns the first `limit` non-excluded nodes,
            // which keeps test outcomes deterministic.
            return allNodes().stream()
                    .filter(node -> !excludedNodes.contains(node))
                    .limit(limit)
                    .collect(toImmutableList());
        }

        @Override
        public SplitPlacementResult computeAssignments(Set<Split> splits, List<RemoteTask> existingTasks)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public SplitPlacementResult computeAssignments(Set<Split> splits, List<RemoteTask> existingTasks, BucketNodeMap bucketNodeMap)
        {
            throw new UnsupportedOperationException();
        }
    }
}
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.execution.scheduler;

import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import io.trino.connector.CatalogName;
import io.trino.execution.Lifespan;
import io.trino.metadata.Split;
import io.trino.spi.connector.ConnectorPartitionHandle;
import io.trino.split.SplitSource;

import java.util.Iterator;
import java.util.List;
import java.util.Optional;

import static com.google.common.util.concurrent.Futures.immediateFuture;
import static java.util.Objects.requireNonNull;

/**
 * {@link SplitSource} stub that serves a fixed list of splits in order.
 *
 * <p>All futures complete immediately; partition handle and lifespan arguments
 * are ignored.
 */
public class TestingSplitSource
        implements SplitSource
{
    private final CatalogName catalogName;
    // Single-pass iterator over a defensive copy of the input splits.
    private final Iterator<Split> splits;

    public TestingSplitSource(CatalogName catalogName, List<Split> splits)
    {
        this.catalogName = requireNonNull(catalogName, "catalogName is null");
        this.splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null")).iterator();
    }

    @Override
    public CatalogName getCatalogName()
    {
        return catalogName;
    }

    /**
     * Returns up to {@code maxSize} splits; the batch is marked "last" once the
     * underlying iterator is exhausted.
     */
    @Override
    public ListenableFuture<SplitBatch> getNextBatch(ConnectorPartitionHandle partitionHandle, Lifespan lifespan, int maxSize)
    {
        if (isFinished()) {
            return immediateFuture(new SplitBatch(ImmutableList.of(), true));
        }
        ImmutableList.Builder<Split> result = ImmutableList.builder();
        for (int i = 0; i < maxSize; i++) {
            if (!splits.hasNext()) {
                break;
            }
            result.add(splits.next());
        }
        return immediateFuture(new SplitBatch(result.build(), isFinished()));
    }

    @Override
    public void close()
    {
    }

    @Override
    public boolean isFinished()
    {
        return !splits.hasNext();
    }

    @Override
    public Optional<List<Object>> getTableExecuteSplitsInfo()
    {
        return Optional.empty();
    }
}
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.execution.scheduler;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import io.trino.execution.RemoteTask;
import io.trino.execution.TaskId;
import io.trino.sql.planner.plan.PlanFragmentId;

import javax.annotation.concurrent.GuardedBy;

import java.util.Set;

import static com.google.common.collect.Sets.newConcurrentHashSet;

/**
 * {@link TaskLifecycleListener} stub that records, per fragment, which task ids
 * were created and which fragments were declared complete, for later assertions.
 */
public class TestingTaskLifecycleListener
        implements TaskLifecycleListener
{
    // ArrayListMultimap is not thread-safe, hence the synchronized accessors.
    @GuardedBy("this")
    private final Multimap<PlanFragmentId, TaskId> tasks = ArrayListMultimap.create();
    // Concurrent set; no lock needed for noMoreTasks bookkeeping.
    private final Set<PlanFragmentId> noMoreTasks = newConcurrentHashSet();

    @Override
    public synchronized void taskCreated(PlanFragmentId fragmentId, RemoteTask task)
    {
        tasks.put(fragmentId, task.getTaskId());
    }

    /**
     * @return immutable snapshot of all (fragment, task id) pairs recorded so far
     */
    public synchronized Multimap<PlanFragmentId, TaskId> getTasks()
    {
        return ImmutableListMultimap.copyOf(tasks);
    }

    @Override
    public void noMoreTasks(PlanFragmentId fragmentId)
    {
        noMoreTasks.add(fragmentId);
    }

    /**
     * @return immutable snapshot of fragments for which noMoreTasks was signaled
     */
    public Set<PlanFragmentId> getNoMoreTasks()
    {
        return ImmutableSet.copyOf(noMoreTasks);
    }
}
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.execution.scheduler;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.Multimap;
import io.trino.Session;
import io.trino.connector.CatalogName;
import io.trino.metadata.Split;
import io.trino.spi.exchange.Exchange;
import io.trino.spi.exchange.ExchangeSourceHandle;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RemoteSourceNode;

import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.LongConsumer;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE;
import static java.util.Objects.requireNonNull;

/**
 * {@link TaskSourceFactory} stub that serves a fixed list of splits for the
 * fragment's single partitioned source, attaching the (replicated) exchange
 * source handles of the fragment's remote sources to every task.
 */
public class TestingTaskSourceFactory
        implements TaskSourceFactory
{
    private final Optional<CatalogName> catalog;
    private final List<Split> splits;
    // Maximum number of task descriptors returned per getMoreTasks() call.
    private final int tasksPerBatch;

    public TestingTaskSourceFactory(Optional<CatalogName> catalog, List<Split> splits, int tasksPerBatch)
    {
        this.catalog = requireNonNull(catalog, "catalog is null");
        this.splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null"));
        this.tasksPerBatch = tasksPerBatch;
    }

    @Override
    public TaskSource create(
            Session session,
            PlanFragment fragment,
            Map<PlanFragmentId, Exchange> sourceExchanges,
            Multimap<PlanFragmentId, ExchangeSourceHandle> exchangeSourceHandles,
            LongConsumer getSplitTimeRecorder,
            Optional<int[]> bucketToPartitionMap,
            Optional<BucketNodeMap> bucketNodeMap)
    {
        List<PlanNodeId> partitionedSources = fragment.getPartitionedSources();
        checkArgument(partitionedSources.size() == 1, "single partitioned source is expected");

        return new TestingTaskSource(
                catalog,
                splits,
                tasksPerBatch,
                getOnlyElement(partitionedSources),
                getHandlesForRemoteSources(fragment.getRemoteSourceNodes(), exchangeSourceHandles));
    }

    /**
     * Re-keys exchange source handles from fragment id to the consuming remote
     * source plan node id. Only REPLICATE remote sources are supported, each
     * source fragment contributing exactly one handle.
     */
    private static Multimap<PlanNodeId, ExchangeSourceHandle> getHandlesForRemoteSources(
            List<RemoteSourceNode> remoteSources,
            Multimap<PlanFragmentId, ExchangeSourceHandle> exchangeSourceHandles)
    {
        ImmutableListMultimap.Builder<PlanNodeId, ExchangeSourceHandle> result = ImmutableListMultimap.builder();
        for (RemoteSourceNode remoteSource : remoteSources) {
            checkArgument(remoteSource.getExchangeType() == REPLICATE, "expected exchange type to be REPLICATE, got: %s", remoteSource.getExchangeType());
            for (PlanFragmentId fragmentId : remoteSource.getSourceFragmentIds()) {
                // Multimap.get never returns null; the size check below is the
                // effective validation (it also rejects a missing fragment).
                Collection<ExchangeSourceHandle> handles = requireNonNull(exchangeSourceHandles.get(fragmentId), () -> "exchange source handle is missing for fragment: " + fragmentId);
                checkArgument(handles.size() == 1, "single exchange source handle is expected, got: %s", handles);
                result.putAll(remoteSource.getId(), handles);
            }
        }
        return result.build();
    }

    /**
     * {@link TaskSource} that emits one task per split, up to tasksPerBatch
     * tasks per {@link #getMoreTasks()} call, with monotonically increasing
     * partition ids.
     */
    public static class TestingTaskSource
            implements TaskSource
    {
        private final Optional<CatalogName> catalogRequirement;
        private final Iterator<Split> splits;
        private final int tasksPerBatch;
        private final PlanNodeId tableScanPlanNodeId;
        private final Multimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles;

        private final AtomicInteger nextPartitionId = new AtomicInteger();

        public TestingTaskSource(
                Optional<CatalogName> catalogRequirement,
                List<Split> splits,
                int tasksPerBatch,
                PlanNodeId tableScanPlanNodeId,
                Multimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles)
        {
            this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null");
            this.splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null")).iterator();
            this.tasksPerBatch = tasksPerBatch;
            this.tableScanPlanNodeId = requireNonNull(tableScanPlanNodeId, "tableScanPlanNodeId is null");
            this.exchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null"));
        }

        @Override
        public List<TaskDescriptor> getMoreTasks()
        {
            // Calling after exhaustion is a test bug; fail loudly.
            checkState(!isFinished(), "already finished");

            ImmutableList.Builder<TaskDescriptor> result = ImmutableList.builder();
            for (int i = 0; i < tasksPerBatch; i++) {
                if (isFinished()) {
                    break;
                }
                Split split = splits.next();
                TaskDescriptor task = new TaskDescriptor(
                        nextPartitionId.getAndIncrement(),
                        ImmutableListMultimap.of(tableScanPlanNodeId, split),
                        exchangeSourceHandles,
                        new NodeRequirements(catalogRequirement, Optional.empty()));
                result.add(task);
            }

            return result.build();
        }

        @Override
        public boolean isFinished()
        {
            return !splits.hasNext();
        }

        @Override
        public void close()
        {
        }
    }
}
driver.updateSource(new TaskSource(sourceId, ImmutableSet.of(new ScheduledSplit(0, sourceId, testSplit)), true)); + driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(new ScheduledSplit(0, sourceId, testSplit)), true)); ListenableFuture blocked = driver.processFor(new Duration(1, NANOSECONDS)); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestDeduplicationExchangeClientBuffer.java b/core/trino-main/src/test/java/io/trino/operator/TestDeduplicatingDirectExchangeBuffer.java similarity index 71% rename from core/trino-main/src/test/java/io/trino/operator/TestDeduplicationExchangeClientBuffer.java rename to core/trino-main/src/test/java/io/trino/operator/TestDeduplicatingDirectExchangeBuffer.java index 4cbb70bc3f08..53d83b12e216 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestDeduplicationExchangeClientBuffer.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestDeduplicatingDirectExchangeBuffer.java @@ -43,7 +43,7 @@ import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; -public class TestDeduplicationExchangeClientBuffer +public class TestDeduplicatingDirectExchangeBuffer { private static final DataSize ONE_KB = DataSize.of(1, KILOBYTE); @@ -51,7 +51,7 @@ public class TestDeduplicationExchangeClientBuffer public void testIsBlocked() { // immediate close - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { ListenableFuture blocked = buffer.isBlocked(); assertBlocked(blocked); buffer.close(); @@ -59,7 +59,7 @@ public void testIsBlocked() } // empty set of tasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), 
ONE_KB, RetryPolicy.QUERY)) { ListenableFuture blocked = buffer.isBlocked(); assertBlocked(blocked); buffer.noMoreTasks(); @@ -67,7 +67,7 @@ public void testIsBlocked() } // single task finishes before noMoreTasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { ListenableFuture blocked = buffer.isBlocked(); assertBlocked(blocked); @@ -83,7 +83,7 @@ public void testIsBlocked() } // single task finishes after noMoreTasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { ListenableFuture blocked = buffer.isBlocked(); assertBlocked(blocked); @@ -99,7 +99,7 @@ public void testIsBlocked() } // single task fails before noMoreTasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { ListenableFuture blocked = buffer.isBlocked(); assertBlocked(blocked); @@ -115,7 +115,7 @@ public void testIsBlocked() } // single task fails after noMoreTasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { ListenableFuture blocked = buffer.isBlocked(); assertBlocked(blocked); @@ -131,7 +131,7 @@ public void testIsBlocked() } // cancelled blocked future doesn't affect other blocked futures - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try 
(DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { ListenableFuture blocked1 = buffer.isBlocked(); ListenableFuture blocked2 = buffer.isBlocked(); assertBlocked(blocked1); @@ -145,16 +145,18 @@ public void testIsBlocked() } @Test - public void testPollPage() + public void testPollPagesQueryLevelRetry() { - testPollPages(ImmutableListMultimap.of(), ImmutableMap.of(), ImmutableList.of()); + testPollPages(RetryPolicy.QUERY, ImmutableListMultimap.of(), ImmutableMap.of(), ImmutableList.of()); testPollPages( + RetryPolicy.QUERY, ImmutableListMultimap.builder() .put(createTaskId(0, 0), createPage("p0a0v0")) .build(), ImmutableMap.of(), ImmutableList.of("p0a0v0")); testPollPages( + RetryPolicy.QUERY, ImmutableListMultimap.builder() .put(createTaskId(0, 0), createPage("p0a0v0")) .put(createTaskId(0, 1), createPage("p0a1v0")) @@ -162,6 +164,7 @@ public void testPollPage() ImmutableMap.of(), ImmutableList.of("p0a1v0")); testPollPages( + RetryPolicy.QUERY, ImmutableListMultimap.builder() .put(createTaskId(0, 0), createPage("p0a0v0")) .put(createTaskId(1, 0), createPage("p1a0v0")) @@ -170,6 +173,7 @@ public void testPollPage() ImmutableMap.of(), ImmutableList.of("p0a1v0")); testPollPages( + RetryPolicy.QUERY, ImmutableListMultimap.builder() .put(createTaskId(0, 0), createPage("p0a0v0")) .put(createTaskId(1, 0), createPage("p1a0v0")) @@ -181,6 +185,7 @@ public void testPollPage() ImmutableList.of("p0a1v0")); RuntimeException error = new RuntimeException("error"); testPollPagesFailure( + RetryPolicy.QUERY, ImmutableListMultimap.builder() .put(createTaskId(0, 0), createPage("p0a0v0")) .put(createTaskId(1, 0), createPage("p1a0v0")) @@ -191,6 +196,7 @@ public void testPollPage() error), error); testPollPagesFailure( + RetryPolicy.QUERY, ImmutableListMultimap.builder() .put(createTaskId(0, 0), createPage("p0a0v0")) .put(createTaskId(1, 0), createPage("p1a0v0")) @@ -202,9 +208,80 @@ public void testPollPage() error); } 
- private static void testPollPages(Multimap pages, Map failures, List expectedValues) + @Test + public void testPollPagesTaskLevelRetry() + { + testPollPages(RetryPolicy.TASK, ImmutableListMultimap.of(), ImmutableMap.of(), ImmutableList.of()); + testPollPages( + RetryPolicy.TASK, + ImmutableListMultimap.of( + createTaskId(0, 0), + createPage("p0a0v0")), + ImmutableMap.of(), + ImmutableList.of("p0a0v0")); + testPollPages( + RetryPolicy.TASK, + ImmutableListMultimap.of( + createTaskId(0, 0), + createPage("p0a0v0"), + createTaskId(0, 1), + createPage("p0a1v0")), + ImmutableMap.of(), + ImmutableList.of("p0a0v0")); + testPollPages( + RetryPolicy.TASK, + ImmutableListMultimap.of( + createTaskId(0, 0), + createPage("p0a0v0"), + createTaskId(1, 0), + createPage("p1a0v0"), + createTaskId(0, 1), + createPage("p0a1v0")), + ImmutableMap.of(), + ImmutableList.of("p0a0v0", "p1a0v0")); + testPollPages( + RetryPolicy.TASK, + ImmutableListMultimap.of( + createTaskId(0, 0), + createPage("p0a0v0"), + createTaskId(1, 1), + createPage("p1a0v0"), + createTaskId(0, 1), + createPage("p0a1v0")), + ImmutableMap.of( + createTaskId(1, 0), + new RuntimeException("error")), + ImmutableList.of("p0a0v0", "p1a0v0")); + RuntimeException error = new RuntimeException("error"); + testPollPagesFailure( + RetryPolicy.TASK, + ImmutableListMultimap.of( + createTaskId(0, 0), + createPage("p0a0v0"), + createTaskId(1, 0), + createPage("p1a0v0"), + createTaskId(0, 1), + createPage("p0a1v0")), + ImmutableMap.of( + createTaskId(2, 2), + error), + error); + testPollPagesFailure( + RetryPolicy.TASK, + ImmutableListMultimap.of( + createTaskId(1, 0), + createPage("p1a0v0"), + createTaskId(0, 1), + createPage("p0a1v0")), + ImmutableMap.of( + createTaskId(0, 1), + error), + error); + } + + private static void testPollPages(RetryPolicy retryPolicy, Multimap pages, Map failures, List expectedValues) { - List actualPages = pollPages(pages, failures); + List actualPages = pollPages(retryPolicy, pages, failures); List 
actualValues = actualPages.stream() .map(SerializedPage::getSlice) .map(Slice::toStringUtf8) @@ -212,14 +289,14 @@ private static void testPollPages(Multimap pages, Map pages, Map failures, Throwable expectedFailure) + private static void testPollPagesFailure(RetryPolicy retryPolicy, Multimap pages, Map failures, Throwable expectedFailure) { - assertThatThrownBy(() -> pollPages(pages, failures)).isEqualTo(expectedFailure); + assertThatThrownBy(() -> pollPages(retryPolicy, pages, failures)).isEqualTo(expectedFailure); } - private static List pollPages(Multimap pages, Map failures) + private static List pollPages(RetryPolicy retryPolicy, Multimap pages, Map failures) { - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, retryPolicy)) { for (TaskId taskId : Sets.union(pages.keySet(), failures.keySet())) { buffer.addTask(taskId); } @@ -250,7 +327,7 @@ private static List pollPages(Multimap p @Test public void testRemovePagesForPreviousAttempts() { - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { assertEquals(buffer.getRetainedSizeInBytes(), 0); TaskId partition0Attempt0 = createTaskId(0, 0); @@ -280,7 +357,7 @@ public void testRemovePagesForPreviousAttempts() @Test public void testBufferOverflow() { - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), DataSize.of(100, BYTE), RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), DataSize.of(100, BYTE), RetryPolicy.QUERY)) { TaskId task = createTaskId(0, 0); SerializedPage page1 = createPage("1234"); @@ -311,21 +388,21 @@ public void testBufferOverflow() 
public void testIsFinished() { // close right away - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { assertFalse(buffer.isFinished()); buffer.close(); assertTrue(buffer.isFinished()); } // 0 tasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { assertFalse(buffer.isFinished()); buffer.noMoreTasks(); assertTrue(buffer.isFinished()); } // single task producing no results, finish before noMoreTasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { assertFalse(buffer.isFinished()); TaskId taskId = createTaskId(0, 0); @@ -340,7 +417,7 @@ public void testIsFinished() } // single task producing no results, finish after noMoreTasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { assertFalse(buffer.isFinished()); TaskId taskId = createTaskId(0, 0); @@ -355,7 +432,7 @@ public void testIsFinished() } // single task producing no results, fail before noMoreTasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { assertFalse(buffer.isFinished()); TaskId taskId = createTaskId(0, 0); @@ -370,7 +447,7 @@ public void 
testIsFinished() } // single task producing no results, fail after noMoreTasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { assertFalse(buffer.isFinished()); TaskId taskId = createTaskId(0, 0); @@ -385,7 +462,7 @@ public void testIsFinished() } // single task producing one page, fail after noMoreTasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { assertFalse(buffer.isFinished()); TaskId taskId = createTaskId(0, 0); @@ -401,7 +478,7 @@ public void testIsFinished() } // single task producing one page, finish after noMoreTasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { assertFalse(buffer.isFinished()); TaskId taskId = createTaskId(0, 0); @@ -420,7 +497,7 @@ public void testIsFinished() } // single task producing one page, finish before noMoreTasks - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { assertFalse(buffer.isFinished()); TaskId taskId = createTaskId(0, 0); @@ -442,7 +519,7 @@ public void testIsFinished() @Test public void testRemainingBufferCapacity() { - try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) { + try (DirectExchangeBuffer buffer = new DeduplicatingDirectExchangeBuffer(directExecutor(), ONE_KB, 
RetryPolicy.QUERY)) { assertFalse(buffer.isFinished()); TaskId taskId = createTaskId(0, 0); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestExchangeClient.java b/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java similarity index 94% rename from core/trino-main/src/test/java/io/trino/operator/TestExchangeClient.java rename to core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java index 430aa6c89e1b..751505d7fac4 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestExchangeClient.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java @@ -83,7 +83,7 @@ import static org.testng.Assert.assertTrue; @Test(singleThreaded = true) -public class TestExchangeClient +public class TestDirectExchangeClient { private ScheduledExecutorService scheduler; private ExecutorService pageBufferClientCallbackExecutor; @@ -123,10 +123,10 @@ public void testHappyPath() pages.forEach(page -> processor.addPage(location, page)); processor.setComplete(location); - TestingExchangeClientBuffer buffer = new TestingExchangeClientBuffer(DataSize.of(1, Unit.MEGABYTE)); + TestingDirectExchangeBuffer buffer = new TestingDirectExchangeBuffer(DataSize.of(1, Unit.MEGABYTE)); @SuppressWarnings("resource") - ExchangeClient exchangeClient = new ExchangeClient( + DirectExchangeClient exchangeClient = new DirectExchangeClient( "localhost", DataIntegrityVerification.ABORT, buffer, @@ -161,7 +161,7 @@ public void testHappyPath() buffer.setFinished(true); assertTrue(exchangeClient.isFinished()); - ExchangeClientStatus status = exchangeClient.getStatus(); + DirectExchangeClientStatus status = exchangeClient.getStatus(); assertEquals(status.getBufferedPages(), 0); // client should have sent only 3 requests: one to get all pages, one to acknowledge and one to get the done signal @@ -189,10 +189,10 @@ public void testStreamingHappyPath() processor.setComplete(location); @SuppressWarnings("resource") - 
ExchangeClient exchangeClient = new ExchangeClient( + DirectExchangeClient exchangeClient = new DirectExchangeClient( "localhost", DataIntegrityVerification.ABORT, - new StreamingExchangeClientBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), maxResponseSize, 1, new Duration(1, TimeUnit.MINUTES), @@ -215,7 +215,7 @@ public void testStreamingHappyPath() assertNull(getNextPage(exchangeClient)); assertTrue(exchangeClient.isFinished()); - ExchangeClientStatus status = exchangeClient.getStatus(); + DirectExchangeClientStatus status = exchangeClient.getStatus(); assertEquals(status.getBufferedPages(), 0); // client should have sent only 3 requests: one to get all pages, one to acknowledge and one to get the done signal @@ -242,10 +242,10 @@ public void testAddLocation() processor.addPage(location1, createSerializedPage("location-1-page-1")); processor.addPage(location1, createSerializedPage("location-1-page-2")); - TestingExchangeClientBuffer buffer = new TestingExchangeClientBuffer(DataSize.of(1, Unit.MEGABYTE)); + TestingDirectExchangeBuffer buffer = new TestingDirectExchangeBuffer(DataSize.of(1, Unit.MEGABYTE)); @SuppressWarnings("resource") - ExchangeClient exchangeClient = new ExchangeClient( + DirectExchangeClient exchangeClient = new DirectExchangeClient( "localhost", DataIntegrityVerification.ABORT, buffer, @@ -313,10 +313,10 @@ public void testStreamingAddLocation() MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); @SuppressWarnings("resource") - ExchangeClient exchangeClient = new ExchangeClient( + DirectExchangeClient exchangeClient = new DirectExchangeClient( "localhost", DataIntegrityVerification.ABORT, - new StreamingExchangeClientBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), maxResponseSize, 1, new Duration(1, TimeUnit.MINUTES), @@ -420,10 +420,10 @@ 
public void testDeduplication() processor.addPage(locationP0A1, createSerializedPage("location-2-page-2")); @SuppressWarnings("resource") - ExchangeClient exchangeClient = new ExchangeClient( + DirectExchangeClient exchangeClient = new DirectExchangeClient( "localhost", DataIntegrityVerification.ABORT, - new DeduplicationExchangeClientBuffer(scheduler, DataSize.of(1, Unit.KILOBYTE), RetryPolicy.QUERY), + new DeduplicatingDirectExchangeBuffer(scheduler, DataSize.of(1, Unit.KILOBYTE), RetryPolicy.QUERY), maxResponseSize, 1, new Duration(1, SECONDS), @@ -492,13 +492,13 @@ public void testTaskFailure() processor.addPage(location4, createSerializedPage("location-4-page-1")); processor.addPage(location4, createSerializedPage("location-4-page-2")); - TestingExchangeClientBuffer buffer = new TestingExchangeClientBuffer(DataSize.of(1, Unit.MEGABYTE)); + TestingDirectExchangeBuffer buffer = new TestingDirectExchangeBuffer(DataSize.of(1, Unit.MEGABYTE)); Set failedTasks = newConcurrentHashSet(); CountDownLatch latch = new CountDownLatch(2); @SuppressWarnings("resource") - ExchangeClient exchangeClient = new ExchangeClient( + DirectExchangeClient exchangeClient = new DirectExchangeClient( "localhost", DataIntegrityVerification.ABORT, buffer, @@ -589,7 +589,7 @@ public void testTaskFailure() assertTrue(exchangeClient.isFinished()); } - private static void assertTaskIsNotFinished(TestingExchangeClientBuffer buffer, TaskId task) + private static void assertTaskIsNotFinished(TestingDirectExchangeBuffer buffer, TaskId task) { assertThatThrownBy(() -> buffer.whenTaskFinished(task).get(50, MILLISECONDS)) .isInstanceOf(TimeoutException.class); @@ -610,10 +610,10 @@ public void testStreamingBufferLimit() processor.setComplete(location); @SuppressWarnings("resource") - ExchangeClient exchangeClient = new ExchangeClient( + DirectExchangeClient exchangeClient = new DirectExchangeClient( "localhost", DataIntegrityVerification.ABORT, - new StreamingExchangeClientBuffer(scheduler, 
DataSize.ofBytes(1)), + new StreamingDirectExchangeBuffer(scheduler, DataSize.ofBytes(1)), maxResponseSize, 1, new Duration(1, TimeUnit.MINUTES), @@ -675,8 +675,7 @@ public void testStreamingBufferLimit() // wait for client to decide there are no more pages assertNull(getNextPage(exchangeClient)); assertEquals(exchangeClient.getStatus().getBufferedPages(), 0); - assertTrue(exchangeClient.getStatus().getBufferedBytes() == 0); - assertEquals(exchangeClient.isFinished(), true); + assertTrue(exchangeClient.isFinished()); exchangeClient.close(); assertStatus(exchangeClient.getStatus().getPageBufferClientStatuses().get(0), location, "closed", 3, 5, 5, "not scheduled"); } @@ -685,7 +684,7 @@ public void testStreamingBufferLimit() public void testStreamingAbortOnDataCorruption() { URI location = URI.create("http://localhost:8080"); - ExchangeClient exchangeClient = setUpDataCorruption(DataIntegrityVerification.ABORT, location); + DirectExchangeClient exchangeClient = setUpDataCorruption(DataIntegrityVerification.ABORT, location); assertThatThrownBy(() -> getNextPage(exchangeClient)) .isInstanceOf(TrinoException.class) @@ -698,7 +697,7 @@ public void testStreamingAbortOnDataCorruption() public void testStreamingRetryDataCorruption() { URI location = URI.create("http://localhost:8080"); - ExchangeClient exchangeClient = setUpDataCorruption(DataIntegrityVerification.RETRY, location); + DirectExchangeClient exchangeClient = setUpDataCorruption(DataIntegrityVerification.RETRY, location); assertFalse(exchangeClient.isFinished()); assertPageEquals(getNextPage(exchangeClient), createPage(1)); @@ -708,14 +707,14 @@ public void testStreamingRetryDataCorruption() assertTrue(exchangeClient.isFinished()); exchangeClient.close(); - ExchangeClientStatus status = exchangeClient.getStatus(); + DirectExchangeClientStatus status = exchangeClient.getStatus(); assertEquals(status.getBufferedPages(), 0); assertEquals(status.getBufferedBytes(), 0); 
assertStatus(status.getPageBufferClientStatuses().get(0), location, "closed", 2, 4, 4, "not scheduled"); } - private ExchangeClient setUpDataCorruption(DataIntegrityVerification dataIntegrityVerification, URI location) + private DirectExchangeClient setUpDataCorruption(DataIntegrityVerification dataIntegrityVerification, URI location) { DataSize maxResponseSize = DataSize.of(10, Unit.MEGABYTE); @@ -761,10 +760,10 @@ public synchronized Response handle(Request request) } }; - ExchangeClient exchangeClient = new ExchangeClient( + DirectExchangeClient exchangeClient = new DirectExchangeClient( "localhost", dataIntegrityVerification, - new StreamingExchangeClientBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), maxResponseSize, 1, new Duration(1, TimeUnit.MINUTES), @@ -794,10 +793,10 @@ public void testStreamingClose() processor.addPage(location, createPage(3)); @SuppressWarnings("resource") - ExchangeClient exchangeClient = new ExchangeClient( + DirectExchangeClient exchangeClient = new DirectExchangeClient( "localhost", DataIntegrityVerification.ABORT, - new StreamingExchangeClientBuffer(scheduler, DataSize.ofBytes(1)), + new StreamingDirectExchangeBuffer(scheduler, DataSize.ofBytes(1)), maxResponseSize, 1, new Duration(1, TimeUnit.MINUTES), @@ -839,7 +838,7 @@ private static SerializedPage createSerializedPage(String value) return new SerializedPage(utf8Slice(value), PageCodecMarker.MarkerSet.empty(), 1, value.length()); } - private static SerializedPage getNextPage(ExchangeClient exchangeClient) + private static SerializedPage getNextPage(DirectExchangeClient exchangeClient) { ListenableFuture futurePage = Futures.transform(exchangeClient.isBlocked(), ignored -> exchangeClient.pollPage(), directExecutor()); return tryGetFutureValue(futurePage, 100, TimeUnit.SECONDS).orElse(null); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestExchangeClientConfig.java 
b/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClientConfig.java similarity index 93% rename from core/trino-main/src/test/java/io/trino/operator/TestExchangeClientConfig.java rename to core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClientConfig.java index 07f2f0ac63b6..92f446eb1339 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestExchangeClientConfig.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClientConfig.java @@ -27,12 +27,12 @@ import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; import static io.airlift.units.DataSize.Unit; -public class TestExchangeClientConfig +public class TestDirectExchangeClientConfig { @Test public void testDefaults() { - assertRecordedDefaults(recordDefaults(ExchangeClientConfig.class) + assertRecordedDefaults(recordDefaults(DirectExchangeClientConfig.class) .setMaxBufferSize(DataSize.of(32, Unit.MEGABYTE)) .setConcurrentRequestMultiplier(3) .setMinErrorDuration(new Duration(5, TimeUnit.MINUTES)) @@ -57,7 +57,7 @@ public void testExplicitPropertyMappings() .put("exchange.acknowledge-pages", "false") .build(); - ExchangeClientConfig expected = new ExchangeClientConfig() + DirectExchangeClientConfig expected = new DirectExchangeClientConfig() .setMaxBufferSize(DataSize.of(1, Unit.GIGABYTE)) .setConcurrentRequestMultiplier(13) .setMinErrorDuration(new Duration(33, TimeUnit.SECONDS)) diff --git a/core/trino-main/src/test/java/io/trino/operator/TestDriver.java b/core/trino-main/src/test/java/io/trino/operator/TestDriver.java index c2549c7f705b..7bee6fb760e5 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestDriver.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestDriver.java @@ -22,7 +22,7 @@ import io.trino.connector.CatalogName; import io.trino.execution.Lifespan; import io.trino.execution.ScheduledSplit; -import io.trino.execution.TaskSource; +import io.trino.execution.SplitAssignment; import 
io.trino.memory.context.LocalMemoryContext; import io.trino.metadata.Split; import io.trino.metadata.TableHandle; @@ -182,7 +182,7 @@ public void testAddSourceFinish() assertFalse(driver.processFor(new Duration(1, TimeUnit.MILLISECONDS)).isDone()); assertFalse(driver.isFinished()); - driver.updateSource(new TaskSource(sourceId, ImmutableSet.of(new ScheduledSplit(0, sourceId, newMockSplit())), true)); + driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(new ScheduledSplit(0, sourceId, newMockSplit())), true)); assertFalse(driver.isFinished()); assertTrue(driver.processFor(new Duration(1, TimeUnit.SECONDS)).isDone()); @@ -285,7 +285,7 @@ public void testBrokenOperatorAddSource() assertTrue(driver.processFor(new Duration(1, TimeUnit.MILLISECONDS)).isDone()); assertFalse(driver.isFinished()); - driver.updateSource(new TaskSource(sourceId, ImmutableSet.of(new ScheduledSplit(0, sourceId, newMockSplit())), true)); + driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(new ScheduledSplit(0, sourceId, newMockSplit())), true)); assertFalse(driver.isFinished()); // processFor always returns NOT_BLOCKED, because DriveLockResult was not acquired diff --git a/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java index 19057f243494..4bf86f31b54b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java @@ -22,16 +22,19 @@ import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.trino.FeaturesConfig.DataIntegrityVerification; +import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.Lifespan; import io.trino.execution.StageId; import io.trino.execution.TaskId; import io.trino.execution.buffer.PagesSerdeFactory; import io.trino.execution.buffer.TestingPagesSerdeFactory; +import 
io.trino.metadata.HandleResolver; import io.trino.metadata.Split; import io.trino.operator.ExchangeOperator.ExchangeOperatorFactory; import io.trino.spi.Page; import io.trino.spi.type.Type; import io.trino.split.RemoteSplit; +import io.trino.split.RemoteSplit.DirectExchangeInput; import io.trino.sql.planner.plan.PlanNodeId; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -75,7 +78,7 @@ public class TestExchangeOperator private ScheduledExecutorService scheduler; private ScheduledExecutorService scheduledExecutor; private HttpClient httpClient; - private ExchangeClientSupplier exchangeClientSupplier; + private DirectExchangeClientSupplier directExchangeClientSupplier; private ExecutorService pageBufferClientCallbackExecutor; @SuppressWarnings("resource") @@ -87,10 +90,10 @@ public void setUp() pageBufferClientCallbackExecutor = Executors.newSingleThreadExecutor(); httpClient = new TestingHttpClient(new TestingExchangeHttpClientHandler(taskBuffers), scheduler); - exchangeClientSupplier = (systemMemoryUsageListener, taskFailureListener, retryPolicy) -> new ExchangeClient( + directExchangeClientSupplier = (systemMemoryUsageListener, taskFailureListener, retryPolicy) -> new DirectExchangeClient( "localhost", DataIntegrityVerification.ABORT, - new StreamingExchangeClientBuffer(scheduler, DataSize.of(32, MEGABYTE)), + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, MEGABYTE)), DataSize.of(10, MEGABYTE), 3, new Duration(1, TimeUnit.MINUTES), @@ -149,7 +152,7 @@ public void testSimple() private static Split newRemoteSplit(TaskId taskId) { - return new Split(REMOTE_CONNECTOR_ID, new RemoteSplit(taskId, URI.create("http://localhost/" + taskId)), Lifespan.taskWide()); + return new Split(REMOTE_CONNECTOR_ID, new RemoteSplit(new DirectExchangeInput(taskId, URI.create("http://localhost/" + taskId))), Lifespan.taskWide()); } @Test @@ -254,7 +257,7 @@ public void testFinish() private SourceOperator createExchangeOperator() { - 
ExchangeOperatorFactory operatorFactory = new ExchangeOperatorFactory(0, new PlanNodeId("test"), exchangeClientSupplier, SERDE_FACTORY, RetryPolicy.NONE); + ExchangeOperatorFactory operatorFactory = new ExchangeOperatorFactory(0, new PlanNodeId("test"), directExchangeClientSupplier, SERDE_FACTORY, RetryPolicy.NONE, new ExchangeManagerRegistry(new HandleResolver())); DriverContext driverContext = createTaskContext(scheduler, scheduledExecutor, TEST_SESSION) .addPipelineContext(0, true, true, false) diff --git a/core/trino-main/src/test/java/io/trino/operator/TestMergeOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestMergeOperator.java index d80e354aa7bb..1c5a6a42a4cb 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestMergeOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestMergeOperator.java @@ -32,6 +32,7 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.split.RemoteSplit; +import io.trino.split.RemoteSplit.DirectExchangeInput; import io.trino.sql.gen.OrderingCompiler; import io.trino.sql.planner.plan.PlanNodeId; import org.testng.annotations.AfterMethod; @@ -74,7 +75,7 @@ public class TestMergeOperator private ScheduledExecutorService executor; private PagesSerdeFactory serdeFactory; private HttpClient httpClient; - private ExchangeClientFactory exchangeClientFactory; + private DirectExchangeClientFactory exchangeClientFactory; private OrderingCompiler orderingCompiler; private LoadingCache taskBuffers; @@ -87,7 +88,7 @@ public void setUp() taskBuffers = CacheBuilder.newBuilder().build(CacheLoader.from(TestingTaskBuffer::new)); httpClient = new TestingHttpClient(new TestingExchangeHttpClientHandler(taskBuffers), executor); - exchangeClientFactory = new ExchangeClientFactory(new NodeInfo("test"), new FeaturesConfig(), new ExchangeClientConfig(), httpClient, executor); + exchangeClientFactory = new DirectExchangeClientFactory(new NodeInfo("test"), new FeaturesConfig(), new 
DirectExchangeClientConfig(), httpClient, executor); orderingCompiler = new OrderingCompiler(new TypeOperators()); } @@ -355,7 +356,7 @@ private MergeOperator createMergeOperator(List sourceTypes, List private static Split createRemoteSplit(TaskId taskId) { - return new Split(ExchangeOperator.REMOTE_CONNECTOR_ID, new RemoteSplit(taskId, URI.create("http://localhost/" + taskId)), Lifespan.taskWide()); + return new Split(ExchangeOperator.REMOTE_CONNECTOR_ID, new RemoteSplit(new DirectExchangeInput(taskId, URI.create("http://localhost/" + taskId))), Lifespan.taskWide()); } private static List pullAvailablePages(Operator operator) diff --git a/core/trino-main/src/test/java/io/trino/operator/TestStreamingExchangeClientBuffer.java b/core/trino-main/src/test/java/io/trino/operator/TestStreamingDirectExchangeBuffer.java similarity index 90% rename from core/trino-main/src/test/java/io/trino/operator/TestStreamingExchangeClientBuffer.java rename to core/trino-main/src/test/java/io/trino/operator/TestStreamingDirectExchangeBuffer.java index 5b139a3e8ae0..561cb5ed1043 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestStreamingExchangeClientBuffer.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestStreamingDirectExchangeBuffer.java @@ -32,7 +32,7 @@ import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; -public class TestStreamingExchangeClientBuffer +public class TestStreamingDirectExchangeBuffer { private static final StageId STAGE_ID = new StageId(new QueryId("query"), 0); private static final TaskId TASK_0 = new TaskId(STAGE_ID, 0, 0); @@ -44,7 +44,7 @@ public class TestStreamingExchangeClientBuffer @Test public void testHappyPath() { - try (StreamingExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { + try (StreamingDirectExchangeBuffer buffer = new StreamingDirectExchangeBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { 
assertFalse(buffer.isFinished()); assertFalse(buffer.isBlocked().isDone()); assertNull(buffer.pollPage()); @@ -106,7 +106,7 @@ public void testHappyPath() @Test public void testClose() { - StreamingExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE)); + StreamingDirectExchangeBuffer buffer = new StreamingDirectExchangeBuffer(directExecutor(), DataSize.of(1, KILOBYTE)); buffer.addTask(TASK_0); buffer.addTask(TASK_1); @@ -125,7 +125,7 @@ public void testClose() public void testIsFinished() { // 0 tasks - try (StreamingExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { + try (StreamingDirectExchangeBuffer buffer = new StreamingDirectExchangeBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { assertFalse(buffer.isFinished()); assertFalse(buffer.isBlocked().isDone()); @@ -136,7 +136,7 @@ public void testIsFinished() } // single task - try (StreamingExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { + try (StreamingDirectExchangeBuffer buffer = new StreamingDirectExchangeBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { assertFalse(buffer.isFinished()); assertFalse(buffer.isBlocked().isDone()); @@ -153,7 +153,7 @@ public void testIsFinished() } // single failed task - try (StreamingExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { + try (StreamingDirectExchangeBuffer buffer = new StreamingDirectExchangeBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { assertFalse(buffer.isFinished()); assertFalse(buffer.isBlocked().isDone()); @@ -174,7 +174,7 @@ public void testIsFinished() @Test public void testFutureCancellationDoesNotAffectOtherFutures() { - try (StreamingExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { + try (StreamingDirectExchangeBuffer buffer = new 
StreamingDirectExchangeBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) { assertFalse(buffer.isFinished()); ListenableFuture blocked1 = buffer.isBlocked(); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestingExchangeClientBuffer.java b/core/trino-main/src/test/java/io/trino/operator/TestingDirectExchangeBuffer.java similarity index 97% rename from core/trino-main/src/test/java/io/trino/operator/TestingExchangeClientBuffer.java rename to core/trino-main/src/test/java/io/trino/operator/TestingDirectExchangeBuffer.java index c26891bc9722..73ddacf256ed 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestingExchangeClientBuffer.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestingDirectExchangeBuffer.java @@ -33,8 +33,8 @@ import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static java.util.Objects.requireNonNull; -public class TestingExchangeClientBuffer - implements ExchangeClientBuffer +public class TestingDirectExchangeBuffer + implements DirectExchangeBuffer { private ListenableFuture blocked = immediateVoidFuture(); private final Set allTasks = new HashSet<>(); @@ -48,7 +48,7 @@ public class TestingExchangeClientBuffer private final Map> taskFinished = new HashMap<>(); private final Map> taskFailed = new HashMap<>(); - public TestingExchangeClientBuffer(DataSize bufferCapacity) + public TestingDirectExchangeBuffer(DataSize bufferCapacity) { this.remainingBufferCapacityInBytes = bufferCapacity.toBytes(); } diff --git a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java index 03f284d01c1d..77f56ad6fc9d 100644 --- a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java +++ b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java @@ -37,11 +37,11 @@ import io.trino.execution.NodeTaskMap; import 
io.trino.execution.QueryManagerConfig; import io.trino.execution.RemoteTask; +import io.trino.execution.SplitAssignment; import io.trino.execution.StageId; import io.trino.execution.TaskId; import io.trino.execution.TaskInfo; import io.trino.execution.TaskManagerConfig; -import io.trino.execution.TaskSource; import io.trino.execution.TaskState; import io.trino.execution.TaskStatus; import io.trino.execution.TaskTestUtils; @@ -181,14 +181,14 @@ public void testRegular() Lifespan lifespan = Lifespan.driverGroup(3); remoteTask.addSplits(ImmutableMultimap.of(TABLE_SCAN_NODE_ID, new Split(new CatalogName("test"), TestingSplit.createLocalSplit(), lifespan))); - poll(() -> testingTaskResource.getTaskSource(TABLE_SCAN_NODE_ID) != null); - poll(() -> testingTaskResource.getTaskSource(TABLE_SCAN_NODE_ID).getSplits().size() == 1); + poll(() -> testingTaskResource.getTaskSplitAssignment(TABLE_SCAN_NODE_ID) != null); + poll(() -> testingTaskResource.getTaskSplitAssignment(TABLE_SCAN_NODE_ID).getSplits().size() == 1); remoteTask.noMoreSplits(TABLE_SCAN_NODE_ID, lifespan); - poll(() -> testingTaskResource.getTaskSource(TABLE_SCAN_NODE_ID).getNoMoreSplitsForLifespan().size() == 1); + poll(() -> testingTaskResource.getTaskSplitAssignment(TABLE_SCAN_NODE_ID).getNoMoreSplitsForLifespan().size() == 1); remoteTask.noMoreSplits(TABLE_SCAN_NODE_ID); - poll(() -> testingTaskResource.getTaskSource(TABLE_SCAN_NODE_ID).isNoMoreSplits()); + poll(() -> testingTaskResource.getTaskSplitAssignment(TABLE_SCAN_NODE_ID).isNoMoreSplits()); remoteTask.cancel(); poll(() -> remoteTask.getTaskStatus().getState().isDone()); @@ -404,8 +404,8 @@ private void addSplit(RemoteTask remoteTask, TestingTaskResource testingTaskReso Lifespan lifespan = Lifespan.driverGroup(3); remoteTask.addSplits(ImmutableMultimap.of(TABLE_SCAN_NODE_ID, new Split(new CatalogName("test"), TestingSplit.createLocalSplit(), lifespan))); // wait for splits to be received by remote task - poll(() -> 
testingTaskResource.getTaskSource(TABLE_SCAN_NODE_ID) != null); - poll(() -> testingTaskResource.getTaskSource(TABLE_SCAN_NODE_ID).getSplits().size() == expectedSplitsCount); + poll(() -> testingTaskResource.getTaskSplitAssignment(TABLE_SCAN_NODE_ID) != null); + poll(() -> testingTaskResource.getTaskSplitAssignment(TABLE_SCAN_NODE_ID).getSplits().size() == expectedSplitsCount); } private RemoteTask createRemoteTask(HttpRemoteTaskFactory httpRemoteTaskFactory, Set outboundDynamicFilterIds) @@ -579,7 +579,7 @@ public synchronized TaskInfo getTaskInfo( return buildTaskInfo(); } - Map taskSourceMap = new HashMap<>(); + Map taskSplitAssignmentMap = new HashMap<>(); @POST @Path("{taskId}") @@ -590,8 +590,8 @@ public synchronized TaskInfo createOrUpdateTask( TaskUpdateRequest taskUpdateRequest, @Context UriInfo uriInfo) { - for (TaskSource source : taskUpdateRequest.getSources()) { - taskSourceMap.compute(source.getPlanNodeId(), (planNodeId, taskSource) -> taskSource == null ? source : taskSource.update(source)); + for (SplitAssignment splitAssignment : taskUpdateRequest.getSplitAssignments()) { + taskSplitAssignmentMap.compute(splitAssignment.getPlanNodeId(), (planNodeId, taskSplitAssignment) -> taskSplitAssignment == null ? 
splitAssignment : taskSplitAssignment.update(splitAssignment)); } if (!taskUpdateRequest.getDynamicFilterDomains().isEmpty()) { dynamicFiltersSentCounter++; @@ -602,13 +602,13 @@ public synchronized TaskInfo createOrUpdateTask( return buildTaskInfo(); } - public synchronized TaskSource getTaskSource(PlanNodeId planNodeId) + public synchronized SplitAssignment getTaskSplitAssignment(PlanNodeId planNodeId) { - TaskSource source = taskSourceMap.get(planNodeId); - if (source == null) { + SplitAssignment assignment = taskSplitAssignmentMap.get(planNodeId); + if (assignment == null) { return null; } - return new TaskSource(source.getPlanNodeId(), source.getSplits(), source.getNoMoreSplitsForLifespan(), source.isNoMoreSplits()); + return new SplitAssignment(assignment.getPlanNodeId(), assignment.getSplits(), assignment.getNoMoreSplitsForLifespan(), assignment.isNoMoreSplits()); } @GET diff --git a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java index 1709da1a6457..5b61daac3d0e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java @@ -121,7 +121,9 @@ public void testDefaults() .setRetryPolicy(RetryPolicy.NONE) .setRetryAttempts(4) .setRetryInitialDelay(new Duration(10, SECONDS)) - .setRetryMaxDelay(new Duration(1, MINUTES))); + .setRetryMaxDelay(new Duration(1, MINUTES)) + .setFaultTolerantExecutionTargetTaskInputSize(DataSize.of(1, GIGABYTE)) + .setFaultTolerantExecutionTargetTaskSplitCount(16)); } @Test @@ -205,6 +207,8 @@ public void testExplicitPropertyMappings() .put("retry-attempts", "0") .put("retry-initial-delay", "1m") .put("retry-max-delay", "1h") + .put("fault-tolerant-execution-target-task-input-size", "222MB") + .put("fault-tolerant-execution-target-task-split-count", "3") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -284,7 
+288,9 @@ public void testExplicitPropertyMappings() .setRetryPolicy(RetryPolicy.QUERY) .setRetryAttempts(0) .setRetryInitialDelay(new Duration(1, MINUTES)) - .setRetryMaxDelay(new Duration(1, HOURS)); + .setRetryMaxDelay(new Duration(1, HOURS)) + .setFaultTolerantExecutionTargetTaskInputSize(DataSize.of(222, MEGABYTE)) + .setFaultTolerantExecutionTargetTaskSplitCount(3); assertFullMapping(properties, expected); } } diff --git a/core/trino-server/src/main/provisio/presto.xml b/core/trino-server/src/main/provisio/presto.xml index 48190e7b21a7..7c65bfb437f3 100644 --- a/core/trino-server/src/main/provisio/presto.xml +++ b/core/trino-server/src/main/provisio/presto.xml @@ -270,4 +270,10 @@ + + + + + + diff --git a/core/trino-spi/src/main/java/io/trino/spi/Plugin.java b/core/trino-spi/src/main/java/io/trino/spi/Plugin.java index 9c69845b3e24..4517c262a274 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/Plugin.java +++ b/core/trino-spi/src/main/java/io/trino/spi/Plugin.java @@ -16,6 +16,7 @@ import io.trino.spi.block.BlockEncoding; import io.trino.spi.connector.ConnectorFactory; import io.trino.spi.eventlistener.EventListenerFactory; +import io.trino.spi.exchange.ExchangeManagerFactory; import io.trino.spi.resourcegroups.ResourceGroupConfigurationManagerFactory; import io.trino.spi.security.CertificateAuthenticatorFactory; import io.trino.spi.security.GroupProviderFactory; @@ -97,4 +98,9 @@ default Iterable getSessionPropertyC { return emptyList(); } + + default Iterable getExchangeManagerFactories() + { + return emptyList(); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java index 846b7ca7c247..06b5eb129722 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +++ b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java @@ -157,6 +157,7 @@ public enum StandardErrorCode INVALID_RESOURCE_GROUP(65561, INTERNAL_ERROR), 
SERIALIZATION_ERROR(65562, INTERNAL_ERROR), REMOTE_TASK_FAILED(65563, INTERNAL_ERROR), + EXCHANGE_MANAGER_NOT_CONFIGURED(65564, INTERNAL_ERROR), GENERIC_INSUFFICIENT_RESOURCES(131072, INSUFFICIENT_RESOURCES), EXCEEDED_GLOBAL_MEMORY_LIMIT(131073, INSUFFICIENT_RESOURCES), diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/Exchange.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/Exchange.java new file mode 100644 index 000000000000..fe4328ea55e4 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/Exchange.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.exchange; + +import io.airlift.slice.Slice; + +import javax.annotation.concurrent.ThreadSafe; + +import java.io.Closeable; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +@ThreadSafe +public interface Exchange + extends Closeable +{ + /** + * Registers a new sink + * + * @param taskPartitionId uniquely identifies a dataset written to a sink + * @return {@link ExchangeSinkHandle} associated with the taskPartitionId. + * Must be passed to {@link #instantiateSink(ExchangeSinkHandle, int)} along with an attempt id to create a sink instance + */ + ExchangeSinkHandle addSink(int taskPartitionId); + + /** + * Called when no more sinks will be added with {@link #addSink(int)}. + * New sink instances for an existing sink may still be added with {@link #instantiateSink(ExchangeSinkHandle, int)}. 
+ */ + void noMoreSinks(); + + /** + * Registers a sink instance for an attempt. + *

+ * Attempts are expected to produce the same data. + *

+ * The implementation must ensure the data written by unsuccessful attempts is safely discarded. + *

+ * When more than a single attempt is successful the implementation must pick one and discard the output of the other attempts + * + * @param sinkHandle - handle returned by addSink + * @param taskAttemptId - attempt id + * @return ExchangeSinkInstanceHandle to be sent to a worker that is needed to create an {@link ExchangeSink} instance using + * {@link ExchangeManager#createSink(ExchangeSinkInstanceHandle)} + */ + ExchangeSinkInstanceHandle instantiateSink(ExchangeSinkHandle sinkHandle, int taskAttemptId); + + /** + * Called by the engine when an attempt finishes successfully + */ + void sinkFinished(ExchangeSinkInstanceHandle handle); + + /** + * Returns a future containing handles to be used to read data from an exchange. + *

+ * Future must be resolved when the data is available to be read. + *

+ * The implementation is expected to return one handle per output partition (see {@link ExchangeSink#add(int, Slice)}) + *

+ * Partitions can be further split if needed by calling {@link #split(ExchangeSourceHandle, long)} + * + * @return Future containing a list of {@link ExchangeSourceHandle} to be sent to a + * worker that is needed to create an {@link ExchangeSource} using {@link ExchangeManager#createSource(List)} + */ + CompletableFuture> getSourceHandles(); + + /** + * Splits an {@link ExchangeSourceHandle} into a number of smaller partitions. + *

+ * Exchange implementation is allowed to return {@link ExchangeSourceHandle} even before all the data + * is written to an exchange. At the moment when the method is called it may not be possible to + * complete the split operation. This method returns a {@link ExchangeSourceSplitter} object + * that allows an iterative splitting while the data is still being written to an exchange. + * + * @param handle returned by the {@link #getSourceHandles()} + * @param targetSizeInBytes desired maximum size of a single partition produced by {@link ExchangeSourceSplitter} + * @return {@link ExchangeSourceSplitter} to be used for iterative splitting of a given partition + */ + ExchangeSourceSplitter split(ExchangeSourceHandle handle, long targetSizeInBytes); + + /** + * Returns statistics (such as size in bytes) for a partition represented by a {@link ExchangeSourceHandle} + * + * @param handle returned by the {@link #getSourceHandles()} or {@link ExchangeSourceSplitter#getNext()} + * @return object containing statistics for a given {@link ExchangeSourceHandle} + */ + ExchangeSourceStatistics getExchangeSourceStatistics(ExchangeSourceHandle handle); + + @Override + void close(); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeContext.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeContext.java new file mode 100644 index 000000000000..de0ba03cb5eb --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeContext.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.exchange; + +import io.trino.spi.QueryId; + +import static java.util.Objects.requireNonNull; + +public class ExchangeContext +{ + private final QueryId queryId; + private final int stageId; + + public ExchangeContext(QueryId queryId, int stageId) + { + this.queryId = requireNonNull(queryId, "queryId is null"); + this.stageId = stageId; + } + + public QueryId getQueryId() + { + return queryId; + } + + public int getStageId() + { + return stageId; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeManager.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeManager.java new file mode 100644 index 000000000000..44f28acd6cf9 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeManager.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.exchange; + +import io.airlift.slice.Slice; + +import javax.annotation.concurrent.ThreadSafe; + +import java.util.List; + +/** + * Service provider interface for an external exchange + *

+ * Used by the engine to exchange data at stage boundaries + *

+ * External exchange is responsible for accepting partitioned data from multiple upstream + * tasks, grouping that data based on the partitionId + * (see {@link ExchangeSink#add(int, Slice)}) and allowing the data to be consumed a + * partition at a time by a set of downstream tasks. + *

+ * To support failure recovery an external exchange implementation is also responsible + * for data deduplication in an event of a task retry or a speculative execution of a task + * (when two identical tasks are running at the same time). The deduplication must be done + * based on the sink identifier (see {@link Exchange#addSink(int)}). The implementation should + * assume that the data written for the same {@link ExchangeSinkHandle} by multiple sink + * instances (see {@link Exchange#instantiateSink(ExchangeSinkHandle, int)}) is identical + * and the data written by an arbitrary instance can be chosen to be delivered while the + * data written by other instances must be safely discarded + */ +@ThreadSafe +public interface ExchangeManager +{ + /** + * Called by the coordinator to initiate an external exchange between a pair of stages + * + * @param context contains various information about the query and stage being executed + * @param outputPartitionCount number of distinct partitions to be created (grouped) by the exchange. + * Values of the partitionId parameter of the {@link ExchangeSink#add(int, Slice)} method + * will be in the [0..outputPartitionCount) range + * @return {@link Exchange} object to be used by the coordinator to interact with the external exchange + */ + Exchange create(ExchangeContext context, int outputPartitionCount); + + /** + * Called by a worker to create an {@link ExchangeSink} for a specific sink instance. + *

+ * A new sink instance is created by the coordinator for every task attempt (see {@link Exchange#instantiateSink(ExchangeSinkHandle, int)}) + * + * @param handle returned by {@link Exchange#instantiateSink(ExchangeSinkHandle, int)} + * @return {@link ExchangeSink} used by the engine to write data to an exchange + */ + ExchangeSink createSink(ExchangeSinkInstanceHandle handle); + + /** + * Called by a worker to create an {@link ExchangeSource} to read data corresponding to + * a given list of exchange source handles. + *

+ * Usually a single {@link ExchangeSourceHandle} corresponds to a single output partition + * (see {@link ExchangeSink#add(int, Slice)}) unless a partition got split by calling + * {@link Exchange#split(ExchangeSourceHandle, long)}. + *

+ * Based on the partition statistic (such as partition size) coordinator may also decide + * to process several partitions by the same task. In such scenarios the handles + * list may contain more than a single element. + * + * @param handles list of {@link ExchangeSourceHandle}'s describing what exchange data to + * read. The full list of handles is returned by {@link Exchange#getSourceHandles}. + * The coordinator decides which items from that list should be handled by which task and creates + * sub-lists that are then sent to a worker to be read. + * @return {@link ExchangeSource} used by the engine to read data from an exchange + */ + ExchangeSource createSource(List handles); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeManagerFactory.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeManagerFactory.java new file mode 100644 index 000000000000..6ac7b5e5dfbe --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeManagerFactory.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.exchange; + +import java.util.Map; + +public interface ExchangeManagerFactory +{ + String getName(); + + ExchangeManager create(Map config); + + ExchangeManagerHandleResolver getHandleResolver(); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeManagerHandleResolver.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeManagerHandleResolver.java new file mode 100644 index 000000000000..b349c6165206 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeManagerHandleResolver.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.exchange; + +public interface ExchangeManagerHandleResolver +{ + Class getExchangeSinkInstanceHandleClass(); + + Class getExchangeSourceHandleHandleClass(); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSink.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSink.java new file mode 100644 index 000000000000..c4b354c1b0c3 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSink.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.exchange; + +import io.airlift.slice.Slice; + +import javax.annotation.concurrent.ThreadSafe; + +import java.util.concurrent.CompletableFuture; + +@ThreadSafe +public interface ExchangeSink +{ + CompletableFuture NOT_BLOCKED = CompletableFuture.completedFuture(null); + + /** + * Returns a future that will be completed when the exchange sink becomes + * unblocked. If the exchange sink is not blocked, this method should return + * {@code NOT_BLOCKED} + */ + CompletableFuture isBlocked(); + + /** + * Appends arbitrary {@code data} to a partition specified by {@code partitionId} + */ + void add(int partitionId, Slice data); + + /** + * Get the total memory that needs to be reserved in the general memory pool. + * This memory should include any buffers, etc. that are used for writing data + */ + long getSystemMemoryUsage(); + + /** + * Notifies the exchange sink that no more data will be appended + */ + void finish(); + + /** + * Notifies the exchange that the write operation has been aborted + */ + void abort(); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSinkHandle.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSinkHandle.java new file mode 100644 index 000000000000..1ea5b2f0cd64 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSinkHandle.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.exchange; + +/** + * Implementation is expected to be Jackson serializable + */ +public interface ExchangeSinkHandle +{ +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSinkInstanceHandle.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSinkInstanceHandle.java new file mode 100644 index 000000000000..0e8f22ca2286 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSinkInstanceHandle.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.exchange; + +/** + * Implementation is expected to be Jackson serializable + */ +public interface ExchangeSinkInstanceHandle +{ +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSource.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSource.java new file mode 100644 index 000000000000..d952e5b77da3 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSource.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.exchange; + +import io.airlift.slice.Slice; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.ThreadSafe; + +import java.io.Closeable; +import java.util.concurrent.CompletableFuture; + +@ThreadSafe +public interface ExchangeSource + extends Closeable +{ + CompletableFuture NOT_BLOCKED = CompletableFuture.completedFuture(null); + + /** + * Returns a future that will be completed when the exchange source becomes + * unblocked. If the exchange source is not blocked, this method should return + * {@code NOT_BLOCKED} + */ + CompletableFuture isBlocked(); + + /** + * Will this exchange source produce more data? + */ + boolean isFinished(); + + /** + * Gets the next chunk of data. This method is allowed to return null. + *

+ * The engine will keep calling this method until {@link #isFinished()} returns {@code true} + * + * @return data written to an exchange using {@link ExchangeSink#add(int, Slice)}. + * The slice is always returned as a whole as written (the exchange does not split and does not merge slices). + */ + @Nullable + Slice read(); + + /** + * Get the total memory that needs to be reserved in the general memory pool. + * This memory should include any buffers, etc. that are used for reading data + */ + long getSystemMemoryUsage(); + + @Override + void close(); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceHandle.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceHandle.java new file mode 100644 index 000000000000..97f0495f8211 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceHandle.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.exchange; + +/** + * Implementation is expected to be Jackson serializable + */ +public interface ExchangeSourceHandle +{ + int getPartitionId(); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceSplitter.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceSplitter.java new file mode 100644 index 000000000000..d7b3572a4d48 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceSplitter.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.exchange; + +import java.io.Closeable; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +public interface ExchangeSourceSplitter + extends Closeable +{ + CompletableFuture NOT_BLOCKED = CompletableFuture.completedFuture(null); + + /** + * Returns a future that will be completed when the splitter becomes + * unblocked. If the splitter is not blocked, this method should return + * {@code NOT_BLOCKED} + */ + CompletableFuture isBlocked(); + + /** + * Returns next sub partition or {@link Optional#empty()} when the splitting process is finished. 
+ */ + Optional getNext(); + + @Override + void close(); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceStatistics.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceStatistics.java new file mode 100644 index 000000000000..ca52fcc203d7 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceStatistics.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.exchange; + +public class ExchangeSourceStatistics +{ + private final long sizeInBytes; + + public ExchangeSourceStatistics(long sizeInBytes) + { + this.sizeInBytes = sizeInBytes; + } + + public long getSizeInBytes() + { + return sizeInBytes; + } +} diff --git a/etc/exchange-manager.properties b/etc/exchange-manager.properties new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/plugin/trino-exchange/pom.xml b/plugin/trino-exchange/pom.xml new file mode 100644 index 000000000000..8225aab97bc7 --- /dev/null +++ b/plugin/trino-exchange/pom.xml @@ -0,0 +1,193 @@ + + + 4.0.0 + + + trino-root + io.trino + 366-SNAPSHOT + ../../pom.xml + + + trino-exchange + Trino - Exchange + trino-plugin + + + ${project.parent.basedir} + 2.17.102 + + + + + + software.amazon.awssdk + bom + pom + ${awsjavasdk.version} + import + + + + + + + io.airlift + bootstrap + + + + io.airlift + concurrent + + + + io.airlift + configuration + + + + io.airlift + log + + + + io.airlift + units + + + + 
com.google.code.findbugs + jsr305 + + + + com.google.guava + guava + + + + com.google.inject + guice + + + + javax.annotation + javax.annotation-api + + + + javax.inject + javax.inject + + + + javax.validation + validation-api + + + + org.reactivestreams + reactive-streams + 1.0.3 + + + + software.amazon.awssdk + auth + + + + software.amazon.awssdk + aws-core + + + + software.amazon.awssdk + regions + + + + software.amazon.awssdk + s3 + + + commons-logging + commons-logging + + + + + + software.amazon.awssdk + sdk-core + + + + software.amazon.awssdk + utils + + + + + io.trino + trino-spi + provided + + + + io.airlift + slice + provided + + + + com.fasterxml.jackson.core + jackson-annotations + provided + + + + org.openjdk.jol + jol-core + provided + + + + + io.trino + trino-main + test + + + + io.trino + trino-testing + test + + + + io.trino + trino-tpch + test + + + + io.trino.tpch + tpch + test + + + + org.testcontainers + testcontainers + test + + + + org.testng + testng + test + + + diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/ExchangeStorageWriter.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/ExchangeStorageWriter.java new file mode 100644 index 000000000000..4d2cf2a56c0c --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/ExchangeStorageWriter.java @@ -0,0 +1,31 @@ +package io.trino.plugin.exchange; +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.slice.Slice; + +import java.io.Closeable; +import java.io.IOException; + +public interface ExchangeStorageWriter + extends Closeable +{ + ListenableFuture write(Slice slice) + throws IOException; + + @Override + void close() + throws IOException; +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchange.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchange.java new file mode 100644 index 000000000000..d0eea9c20aef --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchange.java @@ -0,0 +1,290 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.exchange; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Multimap; +import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeContext; +import io.trino.spi.exchange.ExchangeSinkHandle; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; +import io.trino.spi.exchange.ExchangeSourceHandle; +import io.trino.spi.exchange.ExchangeSourceSplitter; +import io.trino.spi.exchange.ExchangeSourceStatistics; + +import javax.annotation.concurrent.GuardedBy; +import javax.crypto.SecretKey; + +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.URI; +import java.security.Key; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.plugin.exchange.FileSystemExchangeManager.PATH_SEPARATOR; +import static io.trino.plugin.exchange.FileSystemExchangeSink.COMMITTED_MARKER_FILE_NAME; +import static io.trino.plugin.exchange.FileSystemExchangeSink.DATA_FILE_SUFFIX; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class FileSystemExchange + implements Exchange +{ + private static final Pattern PARTITION_FILE_NAME_PATTERN = Pattern.compile("(\\d+)\\.data"); + + private final URI baseDirectory; + private final FileSystemExchangeStorage exchangeStorage; + private final ExchangeContext exchangeContext; + private final int outputPartitionCount; + private final Optional secretKey; + + 
@GuardedBy("this") + private final Set allPartitions = new HashSet<>(); + @GuardedBy("this") + private final Set finishedPartitions = new HashSet<>(); + @GuardedBy("this") + private boolean noMoreSinks; + + private final CompletableFuture> exchangeSourceHandlesFuture = new CompletableFuture<>(); + @GuardedBy("this") + private boolean exchangeSourceHandlesCreated; + + public FileSystemExchange(URI baseDirectory, FileSystemExchangeStorage exchangeStorage, ExchangeContext exchangeContext, int outputPartitionCount, Optional secretKey) + { + this.baseDirectory = requireNonNull(baseDirectory, "baseDirectory is null"); + this.exchangeStorage = requireNonNull(exchangeStorage, "exchangeStorage is null"); + this.exchangeContext = requireNonNull(exchangeContext, "exchangeContext is null"); + this.outputPartitionCount = outputPartitionCount; + this.secretKey = requireNonNull(secretKey, "secretKey is null"); + } + + public void initialize() + { + try { + exchangeStorage.createDirectories(getExchangeDirectory()); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public synchronized ExchangeSinkHandle addSink(int partitionId) + { + FileSystemExchangeSinkHandle sinkHandle = new FileSystemExchangeSinkHandle(partitionId, secretKey.map(Key::getEncoded)); + allPartitions.add(partitionId); + return sinkHandle; + } + + @Override + public void noMoreSinks() + { + synchronized (this) { + noMoreSinks = true; + } + checkInputReady(); + } + + @Override + public ExchangeSinkInstanceHandle instantiateSink(ExchangeSinkHandle sinkHandle, int taskAttemptId) + { + FileSystemExchangeSinkHandle fileSystemExchangeSinkHandle = (FileSystemExchangeSinkHandle) sinkHandle; + URI outputDirectory = getExchangeDirectory() + .resolve(fileSystemExchangeSinkHandle.getPartitionId() + PATH_SEPARATOR) + .resolve(taskAttemptId + PATH_SEPARATOR); + try { + exchangeStorage.createDirectories(outputDirectory); + } + catch (IOException e) { + throw new UncheckedIOException(e); + 
} + + return new FileSystemExchangeSinkInstanceHandle(fileSystemExchangeSinkHandle, outputDirectory, outputPartitionCount); + } + + @Override + public void sinkFinished(ExchangeSinkInstanceHandle handle) + { + synchronized (this) { + FileSystemExchangeSinkInstanceHandle instanceHandle = (FileSystemExchangeSinkInstanceHandle) handle; + finishedPartitions.add(instanceHandle.getSinkHandle().getPartitionId()); + } + checkInputReady(); + } + + private void checkInputReady() + { + verify(!Thread.holdsLock(this)); + List exchangeSourceHandles = null; + synchronized (this) { + if (exchangeSourceHandlesCreated) { + return; + } + if (noMoreSinks && finishedPartitions.containsAll(allPartitions)) { + // input is ready, create exchange source handles + exchangeSourceHandles = createExchangeSourceHandles(); + exchangeSourceHandlesCreated = true; + } + } + if (exchangeSourceHandles != null) { + exchangeSourceHandlesFuture.complete(exchangeSourceHandles); + } + } + + private synchronized List createExchangeSourceHandles() + { + Multimap partitionFilesMap = ArrayListMultimap.create(); + for (Integer partitionId : finishedPartitions) { + URI committedAttemptPath = getCommittedAttemptPath(partitionId); + Map partitions = getCommittedPartitions(committedAttemptPath); + partitions.forEach(partitionFilesMap::put); + } + + ImmutableList.Builder result = ImmutableList.builder(); + for (Integer partitionId : partitionFilesMap.keySet()) { + result.add(new FileSystemExchangeSourceHandle(partitionId, ImmutableList.copyOf(partitionFilesMap.get(partitionId)), secretKey.map(SecretKey::getEncoded))); + } + return result.build(); + } + + private URI getCommittedAttemptPath(int partitionId) + { + URI sinkOutputBasePath = getExchangeDirectory() + .resolve(partitionId + PATH_SEPARATOR); + try { + List attemptPaths = exchangeStorage.listDirectories(sinkOutputBasePath).collect(toImmutableList()); + checkState(!attemptPaths.isEmpty(), "no attempts found under sink output path %s", sinkOutputBasePath); + 
+ return attemptPaths.stream() + .filter(this::isCommitted) + .findFirst() + .orElseThrow(() -> new IllegalStateException(format("no committed attempts found under sink output path %s", sinkOutputBasePath))); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private boolean isCommitted(URI attemptPath) + { + URI commitMarkerFilePath = attemptPath.resolve(COMMITTED_MARKER_FILE_NAME); + try { + return exchangeStorage.exists(commitMarkerFilePath); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private Map getCommittedPartitions(URI committedAttemptPath) + { + try { + List partitionFiles = exchangeStorage.listFiles(committedAttemptPath) + .filter(file -> file.getPath().endsWith(DATA_FILE_SUFFIX)) + .collect(toImmutableList()); + ImmutableMap.Builder result = ImmutableMap.builder(); + for (URI partitionFile : partitionFiles) { + Matcher matcher = PARTITION_FILE_NAME_PATTERN.matcher(new File(partitionFile.getPath()).getName()); + checkState(matcher.matches(), "unexpected partition file: %s", partitionFile); + int partitionId = Integer.parseInt(matcher.group(1)); + result.put(partitionId, partitionFile); + } + return result.build(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private URI getExchangeDirectory() + { + return baseDirectory.resolve(exchangeContext.getQueryId() + "." 
+ exchangeContext.getStageId() + PATH_SEPARATOR); + } + + @Override + public CompletableFuture> getSourceHandles() + { + return exchangeSourceHandlesFuture; + } + + @Override + public ExchangeSourceSplitter split(ExchangeSourceHandle handle, long targetSizeInBytes) + { + FileSystemExchangeSourceHandle sourceHandle = (FileSystemExchangeSourceHandle) handle; + Iterator filesIterator = sourceHandle.getFiles().iterator(); + return new ExchangeSourceSplitter() + { + @Override + public CompletableFuture isBlocked() + { + return NOT_BLOCKED; + } + + @Override + public Optional getNext() + { + if (filesIterator.hasNext()) { + return Optional.of(new FileSystemExchangeSourceHandle(sourceHandle.getPartitionId(), ImmutableList.of(filesIterator.next()), secretKey.map(SecretKey::getEncoded))); + } + return Optional.empty(); + } + + @Override + public void close() + { + } + }; + } + + @Override + public ExchangeSourceStatistics getExchangeSourceStatistics(ExchangeSourceHandle handle) + { + FileSystemExchangeSourceHandle sourceHandle = (FileSystemExchangeSourceHandle) handle; + long sizeInBytes = 0; + for (URI file : sourceHandle.getFiles()) { + try { + sizeInBytes += exchangeStorage.size(file, secretKey); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + return new ExchangeSourceStatistics(sizeInBytes); + } + + @Override + public void close() + { + try { + exchangeStorage.deleteRecursively(getExchangeDirectory()); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeConfig.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeConfig.java new file mode 100644 index 000000000000..60ce2f15762a --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeConfig.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use 
this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.exchange; + +import io.airlift.configuration.Config; + +public class FileSystemExchangeConfig +{ + private String baseDirectory; + private boolean exchangeEncryptionEnabled; + + public String getBaseDirectory() + { + return baseDirectory; + } + + @Config("exchange.base-directory") + public FileSystemExchangeConfig setBaseDirectory(String baseDirectory) + { + this.baseDirectory = baseDirectory; + return this; + } + + public boolean isExchangeEncryptionEnabled() + { + return exchangeEncryptionEnabled; + } + + @Config("exchange.encryption-enabled") + public FileSystemExchangeConfig setExchangeEncryptionEnabled(boolean exchangeEncryptionEnabled) + { + this.exchangeEncryptionEnabled = exchangeEncryptionEnabled; + return this; + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeManager.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeManager.java new file mode 100644 index 000000000000..3b1c0d702656 --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeManager.java @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.exchange; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.TrinoException; +import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeContext; +import io.trino.spi.exchange.ExchangeManager; +import io.trino.spi.exchange.ExchangeSink; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; +import io.trino.spi.exchange.ExchangeSource; +import io.trino.spi.exchange.ExchangeSourceHandle; + +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import javax.inject.Inject; + +import java.net.URI; +import java.security.NoSuchAlgorithmException; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static java.util.Objects.requireNonNull; + +public class FileSystemExchangeManager + implements ExchangeManager +{ + public static final String AES = "AES"; + public static final int KEY_BITS = 256; + public static final String PATH_SEPARATOR = "/"; + + private final FileSystemExchangeStorage exchangeStorage; + private final URI baseDirectory; + private final boolean exchangeEncryptionEnabled; + + @Inject + public FileSystemExchangeManager(FileSystemExchangeStorage exchangeStorage, FileSystemExchangeConfig fileSystemExchangeConfig) + { + this.exchangeStorage = requireNonNull(exchangeStorage, "exchangeStorage is null"); + String baseDirectory = 
requireNonNull(fileSystemExchangeConfig.getBaseDirectory(), "baseDirectory is null"); + if (!baseDirectory.endsWith(PATH_SEPARATOR)) { + // This is needed as URI's resolve method expects directories to end with '/' + baseDirectory += PATH_SEPARATOR; + } + this.baseDirectory = URI.create(baseDirectory); + this.exchangeEncryptionEnabled = fileSystemExchangeConfig.isExchangeEncryptionEnabled(); + } + + @Override + public Exchange create(ExchangeContext context, int outputPartitionCount) + { + Optional secretKey = Optional.empty(); + if (exchangeEncryptionEnabled) { + try { + KeyGenerator keyGenerator = KeyGenerator.getInstance(AES); + keyGenerator.init(KEY_BITS); + secretKey = Optional.of(keyGenerator.generateKey()); + } + catch (NoSuchAlgorithmException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to generate new secret key: " + e.getMessage(), e); + } + } + FileSystemExchange exchange = new FileSystemExchange(baseDirectory, exchangeStorage, context, outputPartitionCount, secretKey); + exchange.initialize(); + return exchange; + } + + @Override + public ExchangeSink createSink(ExchangeSinkInstanceHandle handle) + { + FileSystemExchangeSinkInstanceHandle instanceHandle = (FileSystemExchangeSinkInstanceHandle) handle; + return new FileSystemExchangeSink( + exchangeStorage, + instanceHandle.getOutputDirectory(), + instanceHandle.getOutputPartitionCount(), + instanceHandle.getSinkHandle().getSecretKey().map(key -> new SecretKeySpec(key, 0, key.length, AES))); + } + + @Override + public ExchangeSource createSource(List handles) + { + List files = handles.stream() + .map(FileSystemExchangeSourceHandle.class::cast) + .flatMap(handle -> handle.getFiles().stream()) + .collect(toImmutableList()); + ImmutableList.Builder> secretKeys = ImmutableList.builder(); + for (ExchangeSourceHandle handle : handles) { + FileSystemExchangeSourceHandle sourceHandle = (FileSystemExchangeSourceHandle) handle; + 
secretKeys.addAll(Collections.nCopies(sourceHandle.getFiles().size(), sourceHandle.getSecretKey().map(key -> new SecretKeySpec(key, 0, key.length, AES)))); + } + return new FileSystemExchangeSource(exchangeStorage, files, secretKeys.build()); + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeManagerFactory.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeManagerFactory.java new file mode 100644 index 000000000000..e68956f101f4 --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeManagerFactory.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.exchange; + +import com.google.inject.Injector; +import io.airlift.bootstrap.Bootstrap; +import io.trino.spi.exchange.ExchangeManager; +import io.trino.spi.exchange.ExchangeManagerFactory; +import io.trino.spi.exchange.ExchangeManagerHandleResolver; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; +import io.trino.spi.exchange.ExchangeSourceHandle; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class FileSystemExchangeManagerFactory + implements ExchangeManagerFactory +{ + @Override + public String getName() + { + return "filesystem"; + } + + @Override + public ExchangeManager create(Map config) + { + requireNonNull(config, "config is null"); + + Bootstrap app = new Bootstrap(new FileSystemExchangeModule()); + + Injector injector = app + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .initialize(); + + return injector.getInstance(FileSystemExchangeManager.class); + } + + @Override + public ExchangeManagerHandleResolver getHandleResolver() + { + return new ExchangeManagerHandleResolver() + { + @Override + public Class getExchangeSinkInstanceHandleClass() + { + return FileSystemExchangeSinkInstanceHandle.class; + } + + @Override + public Class getExchangeSourceHandleHandleClass() + { + return FileSystemExchangeSourceHandle.class; + } + }; + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeModule.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeModule.java new file mode 100644 index 000000000000..e1db43be9c24 --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeModule.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.exchange; + +import com.google.common.collect.ImmutableSet; +import com.google.inject.Binder; +import com.google.inject.Scopes; +import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.plugin.exchange.local.LocalFileSystemExchangeStorage; +import io.trino.plugin.exchange.s3.ExchangeS3Config; +import io.trino.plugin.exchange.s3.S3FileSystemExchangeStorage; +import io.trino.spi.TrinoException; + +import java.net.URI; + +import static io.airlift.configuration.ConfigBinder.configBinder; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class FileSystemExchangeModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + binder.bind(FileSystemExchangeManager.class).in(Scopes.SINGLETON); + + FileSystemExchangeConfig fileSystemExchangeConfig = buildConfigObject(FileSystemExchangeConfig.class); + String scheme = URI.create(requireNonNull(fileSystemExchangeConfig.getBaseDirectory(), "baseDirectory is null")).getScheme(); + if (scheme == null || scheme.equals("file")) { + binder.bind(FileSystemExchangeStorage.class).to(LocalFileSystemExchangeStorage.class).in(Scopes.SINGLETON); + } + else if (ImmutableSet.of("s3", "s3a", "s3n").contains(scheme)) { + binder.bind(FileSystemExchangeStorage.class).to(S3FileSystemExchangeStorage.class).in(Scopes.SINGLETON); + configBinder(binder).bindConfig(ExchangeS3Config.class); + } + else { + throw new 
TrinoException(NOT_SUPPORTED, format("Scheme %s is not supported as exchange spooling storage", scheme)); + } + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangePlugin.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangePlugin.java new file mode 100644 index 000000000000..04e5b08dc19c --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangePlugin.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.exchange; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.Plugin; +import io.trino.spi.exchange.ExchangeManagerFactory; + +public class FileSystemExchangePlugin + implements Plugin +{ + @Override + public Iterable getExchangeManagerFactories() + { + return ImmutableList.of(new FileSystemExchangeManagerFactory()); + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSink.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSink.java new file mode 100644 index 000000000000..f2d9821a216a --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSink.java @@ -0,0 +1,249 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.exchange; + +import io.airlift.log.Logger; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceOutput; +import io.airlift.slice.Slices; +import io.trino.spi.exchange.ExchangeSink; +import org.openjdk.jol.info.ClassLayout; + +import javax.annotation.concurrent.GuardedBy; +import javax.crypto.SecretKey; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.URI; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static java.lang.Math.min; +import static java.util.Objects.requireNonNull; + +public class FileSystemExchangeSink + implements ExchangeSink +{ + private static final Logger log = Logger.get(FileSystemExchangeSink.class); + + public static final String COMMITTED_MARKER_FILE_NAME = "committed"; + public static final String DATA_FILE_SUFFIX = ".data"; + + private static final int INSTANCE_SIZE = ClassLayout.parseClass(FileSystemExchangeSink.class).instanceSize(); + + private final FileSystemExchangeStorage exchangeStorage; + private final URI outputDirectory; + private final int outputPartitionCount; + private final int numBuffers; + private final Optional secretKey; + private final int writeBufferSize; + + private 
final Map writers = new ConcurrentHashMap<>(); + private final Map pendingBuffers = new ConcurrentHashMap<>(); + private final LinkedBlockingQueue freeBuffers = new LinkedBlockingQueue<>(); + private volatile boolean committed; + private volatile boolean closed; + @GuardedBy("this") + private CompletableFuture blockedFuture = new CompletableFuture<>(); + + public FileSystemExchangeSink(FileSystemExchangeStorage exchangeStorage, URI outputDirectory, int outputPartitionCount, Optional secretKey) + { + this.exchangeStorage = requireNonNull(exchangeStorage, "exchangeStorage is null"); + this.outputDirectory = requireNonNull(outputDirectory, "outputDirectory is null"); + this.outputPartitionCount = outputPartitionCount; + this.numBuffers = outputPartitionCount * 2; + this.secretKey = requireNonNull(secretKey, "secretKey is null"); + this.writeBufferSize = exchangeStorage.getWriteBufferSizeInBytes(); + + for (int i = 0; i < numBuffers; ++i) { + freeBuffers.add(Slices.allocate(writeBufferSize).getOutput()); + } + } + + @Override + public CompletableFuture isBlocked() + { + if (freeBuffers.isEmpty() && pendingBuffers.size() < outputPartitionCount) { + synchronized (this) { + if (blockedFuture.isDone()) { + blockedFuture = new CompletableFuture<>(); + } + return blockedFuture; + } + } + else { + return NOT_BLOCKED; + } + } + + @Override + public void add(int partitionId, Slice data) + { + checkArgument(partitionId < outputPartitionCount, "partition id is expected to be less than %s: %s", outputPartitionCount, partitionId); + checkState(!committed, "already committed"); + if (closed) { + return; + } + + writers.computeIfAbsent(partitionId, this::createWriter); + synchronized (writers.get(partitionId)) { + writeToExchangeStorage(partitionId, Slices.wrappedIntArray(data.length())); + writeToExchangeStorage(partitionId, data); + } + } + + private ExchangeStorageWriter createWriter(int partitionId) + { + URI outputPath = outputDirectory.resolve(partitionId + DATA_FILE_SUFFIX); + 
try { + return exchangeStorage.createExchangeStorageWriter(outputPath, secretKey); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private void writeToExchangeStorage(int partitionId, Slice slice) + { + int position = 0; + while (position < slice.length()) { + SliceOutput pendingBuffer = pendingBuffers.computeIfAbsent(partitionId, ignored -> { + try { + return freeBuffers.take(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + }); + int writableBytes = min(pendingBuffer.writableBytes(), slice.length() - position); + pendingBuffer.writeBytes(slice.getBytes(position, writableBytes)); + position += writableBytes; + + flushIfNeeded(partitionId, false); + } + } + + private void flushIfNeeded(int partitionId, boolean finished) + { + SliceOutput buffer = pendingBuffers.get(partitionId); + if (!buffer.isWritable() || finished) { + if (!buffer.isWritable()) { + pendingBuffers.remove(partitionId); + } + try { + writers.get(partitionId).write(buffer.slice()).addListener(() -> { + buffer.reset(); + freeBuffers.add(buffer); + synchronized (this) { + if (!blockedFuture.isDone()) { + blockedFuture.complete(null); + } + } + }, directExecutor()); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } + + @Override + public long getSystemMemoryUsage() + { + return INSTANCE_SIZE + (long) numBuffers * writeBufferSize; + } + + @Override + public void finish() + { + if (closed) { + return; + } + try { + for (Integer partitionId : writers.keySet()) { + flushIfNeeded(partitionId, true); + try { + pendingBuffers.get(partitionId).close(); + writers.get(partitionId).close(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + for (SliceOutput freeBuffer : freeBuffers) { + try { + freeBuffer.close(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + try { + 
exchangeStorage.createEmptyFile(outputDirectory.resolve(COMMITTED_MARKER_FILE_NAME)); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + catch (Throwable t) { + abort(); + throw t; + } + pendingBuffers.clear(); + freeBuffers.clear(); + committed = true; + closed = true; + } + + @Override + public void abort() + { + if (closed) { + return; + } + closed = true; + for (Integer partitionId : writers.keySet()) { + try { + pendingBuffers.get(partitionId).close(); + writers.get(partitionId).close(); + } + catch (IOException e) { + log.warn(e, "Error closing pending buffer and writer for exchanges"); + } + } + for (SliceOutput freeBuffer : freeBuffers) { + try { + freeBuffer.close(); + } + catch (IOException e) { + log.warn(e, "Error closing free buffer for exchanges"); + } + } + pendingBuffers.clear(); + freeBuffers.clear(); + try { + exchangeStorage.deleteRecursively(outputDirectory); + } + catch (IOException e) { + log.warn(e, "Error cleaning output directory"); + } + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSinkHandle.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSinkHandle.java new file mode 100644 index 000000000000..9757c2819910 --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSinkHandle.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.exchange; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.exchange.ExchangeSinkHandle; + +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class FileSystemExchangeSinkHandle + implements ExchangeSinkHandle +{ + private final int partitionId; + private final Optional secretKey; + + @JsonCreator + public FileSystemExchangeSinkHandle( + @JsonProperty("partitionId") int partitionId, + @JsonProperty("secretKey") Optional secretKey) + { + this.partitionId = partitionId; + this.secretKey = requireNonNull(secretKey, "secretKey is null"); + } + + @JsonProperty + public int getPartitionId() + { + return partitionId; + } + + @JsonProperty + public Optional getSecretKey() + { + return secretKey; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("partitionId", partitionId) + .add("secretKey", secretKey.map(value -> "[REDACTED]")) + .toString(); + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSinkInstanceHandle.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSinkInstanceHandle.java new file mode 100644 index 000000000000..0335c5716cb7 --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSinkInstanceHandle.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.exchange; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; + +import java.net.URI; + +import static java.util.Objects.requireNonNull; + +public class FileSystemExchangeSinkInstanceHandle + implements ExchangeSinkInstanceHandle +{ + private final FileSystemExchangeSinkHandle sinkHandle; + private final URI outputDirectory; + private final int outputPartitionCount; + + @JsonCreator + public FileSystemExchangeSinkInstanceHandle( + @JsonProperty("sinkHandle") FileSystemExchangeSinkHandle sinkHandle, + @JsonProperty("outputDirectory") URI outputDirectory, + @JsonProperty("outputPartitionCount") int outputPartitionCount) + { + this.sinkHandle = requireNonNull(sinkHandle, "sinkHandle is null"); + this.outputDirectory = requireNonNull(outputDirectory, "outputDirectory is null"); + this.outputPartitionCount = outputPartitionCount; + } + + @JsonProperty + public FileSystemExchangeSinkHandle getSinkHandle() + { + return sinkHandle; + } + + @JsonProperty + public URI getOutputDirectory() + { + return outputDirectory; + } + + @JsonProperty + public int getOutputPartitionCount() + { + return outputPartitionCount; + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSource.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSource.java new file mode 100644 index 000000000000..8cd0261172fa --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSource.java @@ -0,0 +1,126 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.exchange; + +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceInput; +import io.trino.spi.exchange.ExchangeSource; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.crypto.SecretKey; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.URI; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static java.util.Objects.requireNonNull; + +public class FileSystemExchangeSource + implements ExchangeSource +{ + private final FileSystemExchangeStorage exchangeStorage; + @GuardedBy("this") + private final Iterator files; + @GuardedBy("this") + private final Iterator> secretKeys; + + @GuardedBy("this") + private SliceInput sliceInput; + @GuardedBy("this") + private boolean closed; + + public FileSystemExchangeSource(FileSystemExchangeStorage exchangeStorage, List files, List> secretKeys) + { + this.exchangeStorage = requireNonNull(exchangeStorage, "exchangeStorage is null"); + this.files = ImmutableList.copyOf(requireNonNull(files, "files is null")).iterator(); + this.secretKeys = ImmutableList.copyOf(requireNonNull(secretKeys, "secretKeys is null")).stream().iterator(); + } + + @Override + public CompletableFuture isBlocked() + { + return NOT_BLOCKED; + } + + @Override + public synchronized boolean isFinished() + { + return closed || (!files.hasNext() && sliceInput == null); + } + + @Nullable + @Override 
+ public synchronized Slice read() + { + if (isFinished()) { + return null; + } + + if (sliceInput != null && !sliceInput.isReadable()) { + sliceInput.close(); + sliceInput = null; + } + + if (sliceInput == null) { + if (files.hasNext()) { + // TODO: implement parallel read + URI file = files.next(); + Optional secretKey = secretKeys.next(); + try { + sliceInput = exchangeStorage.getSliceInput(file, secretKey); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } + + if (sliceInput == null) { + return null; + } + + if (!sliceInput.isReadable()) { + sliceInput.close(); + sliceInput = null; + return null; + } + + int size = sliceInput.readInt(); + return sliceInput.readSlice(size); + } + + @Override + public synchronized long getSystemMemoryUsage() + { + return sliceInput != null ? sliceInput.getRetainedSize() : 0; + } + + @Override + public synchronized void close() + { + if (!closed) { + closed = true; + if (sliceInput != null) { + sliceInput.close(); + sliceInput = null; + } + } + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSourceHandle.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSourceHandle.java new file mode 100644 index 000000000000..adcae689cd81 --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeSourceHandle.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.exchange; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.trino.spi.exchange.ExchangeSourceHandle; + +import java.net.URI; +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class FileSystemExchangeSourceHandle + implements ExchangeSourceHandle +{ + private final int partitionId; + private final List files; + private final Optional secretKey; + + @JsonCreator + public FileSystemExchangeSourceHandle( + @JsonProperty("partitionId") int partitionId, + @JsonProperty("files") List files, + @JsonProperty("secretKey") Optional secretKey) + { + this.partitionId = partitionId; + this.files = ImmutableList.copyOf(requireNonNull(files, "files is null")); + this.secretKey = requireNonNull(secretKey, "secretKey is null"); + } + + @Override + @JsonProperty + public int getPartitionId() + { + return partitionId; + } + + @JsonProperty + public List getFiles() + { + return files; + } + + @JsonProperty + public Optional getSecretKey() + { + return secretKey; + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeStorage.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeStorage.java new file mode 100644 index 000000000000..e55318a68a08 --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/FileSystemExchangeStorage.java @@ -0,0 +1,52 @@ +package io.trino.plugin.exchange; +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import io.airlift.slice.SliceInput; + +import javax.crypto.SecretKey; + +import java.io.IOException; +import java.net.URI; +import java.util.Optional; +import java.util.stream.Stream; + +public interface FileSystemExchangeStorage + extends AutoCloseable +{ + void initialize(URI baseDirectory); + + void createDirectories(URI dir) throws IOException; + + SliceInput getSliceInput(URI file, Optional secretKey) throws IOException; + + ExchangeStorageWriter createExchangeStorageWriter(URI file, Optional secretKey) throws IOException; + + boolean exists(URI file) throws IOException; + + void createEmptyFile(URI file) throws IOException; + + void deleteRecursively(URI dir) throws IOException; + + Stream listFiles(URI dir) throws IOException; + + Stream listDirectories(URI dir) throws IOException; + + long size(URI file, Optional secretKey) throws IOException; + + int getWriteBufferSizeInBytes(); + + @Override + void close() throws IOException; +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/local/LocalFileSystemExchangeStorage.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/local/LocalFileSystemExchangeStorage.java new file mode 100644 index 000000000000..6be926cc426f --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/local/LocalFileSystemExchangeStorage.java @@ -0,0 +1,201 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.exchange.local; + +import com.google.common.collect.ImmutableList; +import com.google.common.io.MoreFiles; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.slice.InputStreamSliceInput; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceInput; +import io.airlift.units.DataSize; +import io.trino.plugin.exchange.ExchangeStorageWriter; +import io.trino.plugin.exchange.FileSystemExchangeStorage; +import io.trino.spi.TrinoException; + +import javax.crypto.Cipher; +import javax.crypto.CipherInputStream; +import javax.crypto.CipherOutputStream; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; +import java.util.stream.Stream; + +import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; +import static com.google.common.util.concurrent.Futures.immediateVoidFuture; +import static io.airlift.units.DataSize.Unit.KILOBYTE; +import static io.trino.plugin.exchange.FileSystemExchangeManager.AES; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static java.lang.Math.toIntExact; 
+import static java.nio.file.Files.createFile; + +public class LocalFileSystemExchangeStorage + implements FileSystemExchangeStorage +{ + private static final int BUFFER_SIZE_IN_BYTES = toIntExact(DataSize.of(4, KILOBYTE).toBytes()); + + @Override + public void initialize(URI baseDirectory) + { + // no need to do anything for local file system + } + + @Override + public void createDirectories(URI dir) + throws IOException + { + Files.createDirectories(Paths.get(dir.getPath())); + } + + @Override + public SliceInput getSliceInput(URI file, Optional secretKey) + throws IOException + { + if (secretKey.isPresent()) { + try { + final Cipher cipher = Cipher.getInstance(AES); + cipher.init(Cipher.DECRYPT_MODE, secretKey.get()); + return new InputStreamSliceInput(new CipherInputStream(new FileInputStream(Paths.get(file.getPath()).toFile()), cipher), BUFFER_SIZE_IN_BYTES); + } + catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to create CipherInputStream " + e.getMessage(), e); + } + } + else { + return new InputStreamSliceInput(new FileInputStream(Paths.get(file.getPath()).toFile()), BUFFER_SIZE_IN_BYTES); + } + } + + @Override + public ExchangeStorageWriter createExchangeStorageWriter(URI file, Optional secretKey) + throws IOException + { + return new LocalExchangeStorageWriter(file, secretKey); + } + + @Override + public boolean exists(URI file) + { + return Files.exists(Paths.get(file.getPath())); + } + + @Override + public void createEmptyFile(URI file) + throws IOException + { + createFile(Paths.get(file.getPath())); + } + + @Override + public void deleteRecursively(URI dir) + throws IOException + { + MoreFiles.deleteRecursively(Paths.get(dir.getPath()), ALLOW_INSECURE); + } + + @Override + public Stream listFiles(URI dir) + throws IOException + { + return listPaths(dir, Files::isRegularFile).stream(); + } + + @Override + public Stream listDirectories(URI dir) + throws 
IOException + { + return listPaths(dir, Files::isDirectory).stream(); + } + + @Override + public long size(URI file, Optional secretKey) + throws IOException + { + return Files.size(Paths.get(file.getPath())); + } + + @Override + public int getWriteBufferSizeInBytes() + { + return BUFFER_SIZE_IN_BYTES; + } + + @Override + public void close() + { + } + + private static List listPaths(URI directory, Predicate predicate) + throws IOException + { + ImmutableList.Builder builder = ImmutableList.builder(); + try (Stream dir = Files.list(Paths.get(directory.getPath()))) { + dir.filter(predicate).map(Path::toUri).forEach(builder::add); + } + return builder.build(); + } + + private static class LocalExchangeStorageWriter + implements ExchangeStorageWriter + { + private final OutputStream outputStream; + + public LocalExchangeStorageWriter(URI file, Optional secretKey) + throws FileNotFoundException + { + if (secretKey.isPresent()) { + try { + final Cipher cipher = Cipher.getInstance(AES); + cipher.init(Cipher.ENCRYPT_MODE, secretKey.get()); + this.outputStream = new CipherOutputStream(new FileOutputStream(Paths.get(file.getPath()).toFile()), cipher); + } + catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to create CipherOutputStream " + e.getMessage(), e); + } + } + else { + this.outputStream = new FileOutputStream(Paths.get(file.getPath()).toFile()); + } + } + + @Override + public ListenableFuture write(Slice slice) + throws IOException + { + outputStream.write(slice.getBytes()); + return immediateVoidFuture(); + } + + @Override + public void close() + throws IOException + { + outputStream.close(); + } + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/s3/DirectByteArrayAsyncRequestBody.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/s3/DirectByteArrayAsyncRequestBody.java new file mode 100644 index 000000000000..6da8cf5336fa --- 
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.plugin.exchange.s3;

import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.internal.async.ByteArrayAsyncRequestBody;
import software.amazon.awssdk.core.internal.util.Mimetype;
import software.amazon.awssdk.utils.Logger;

import java.nio.ByteBuffer;
import java.util.Optional;

/**
 * This class mimics the implementation of {@link ByteArrayAsyncRequestBody} except for we use a ByteBuffer
 * to avoid unnecessary memory copies
 *
 * An implementation of {@link AsyncRequestBody} for providing data from memory. This is created using static
 * methods on {@link AsyncRequestBody}
 *
 * @see AsyncRequestBody#fromBytes(byte[])
 * @see AsyncRequestBody#fromByteBuffer(ByteBuffer)
 * @see AsyncRequestBody#fromString(String)
 */
public final class DirectByteArrayAsyncRequestBody
        implements AsyncRequestBody
{
    private static final Logger log = Logger.loggerFor(DirectByteArrayAsyncRequestBody.class);

    private final ByteBuffer byteBuffer;

    private final String mimetype;

    public DirectByteArrayAsyncRequestBody(ByteBuffer byteBuffer, String mimetype)
    {
        this.byteBuffer = byteBuffer;
        this.mimetype = mimetype;
    }

    @Override
    public Optional<Long> contentLength()
    {
        // assumes the buffer was sized exactly to the payload (capacity == remaining);
        // confirm at call sites before reusing larger pooled buffers here
        return Optional.of((long) byteBuffer.capacity());
    }

    @Override
    public String contentType()
    {
        return mimetype;
    }

    @Override
    public void subscribe(Subscriber<? super ByteBuffer> s)
    {
        // As per rule 1.9 we must throw NullPointerException if the subscriber parameter is null
        if (s == null) {
            throw new NullPointerException("Subscription MUST NOT be null.");
        }

        // As per 2.13, this method must return normally (i.e. not throw).
        try {
            s.onSubscribe(
                    new Subscription() {
                        // Guarded by 'this' so request() and cancel() agree on terminal state;
                        // the original read it unsynchronized in request()
                        private boolean done;

                        @Override
                        public void request(long n)
                        {
                            synchronized (this) {
                                if (done) {
                                    return;
                                }
                                // Terminal either way: one onNext/onComplete, or one onError
                                done = true;
                            }
                            if (n > 0) {
                                // Hand out a read-only view so the subscriber cannot mutate
                                // the shared buffer's content or position
                                s.onNext(byteBuffer.asReadOnlyBuffer());
                                s.onComplete();
                            }
                            else {
                                s.onError(new IllegalArgumentException("§3.9: non-positive requests are not allowed!"));
                            }
                        }

                        @Override
                        public void cancel()
                        {
                            synchronized (this) {
                                done = true;
                            }
                        }
                    });
        }
        catch (Throwable ex) {
            log.error(() -> s + " violated the Reactive Streams rule 2.13 by throwing an exception from onSubscribe.", ex);
        }
    }

    static AsyncRequestBody fromByteBuffer(ByteBuffer byteBuffer)
    {
        return new DirectByteArrayAsyncRequestBody(byteBuffer, Mimetype.MIMETYPE_OCTET_STREAM);
    }
}
+ */ +package io.trino.plugin.exchange.s3; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; +import io.airlift.configuration.ConfigSecuritySensitive; +import io.airlift.units.DataSize; +import io.airlift.units.MaxDataSize; +import io.airlift.units.MinDataSize; + +import javax.validation.constraints.Min; +import javax.validation.constraints.NotNull; + +import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static java.lang.Math.max; + +public class ExchangeS3Config +{ + private String s3AwsAccessKey; + private String s3AwsSecretKey; + private String s3Region; + private String s3Endpoint; + private int s3MaxErrorRetries = 10; + private int deletionThreadCount = max(1, Runtime.getRuntime().availableProcessors() / 2); + private DataSize s3UploadPartSize = DataSize.of(5, MEGABYTE); + + public String getS3AwsAccessKey() + { + return s3AwsAccessKey; + } + + @Config("exchange.s3.aws-access-key") + public ExchangeS3Config setS3AwsAccessKey(String s3AwsAccessKey) + { + this.s3AwsAccessKey = s3AwsAccessKey; + return this; + } + + public String getS3AwsSecretKey() + { + return s3AwsSecretKey; + } + + @Config("exchange.s3.aws-secret-key") + @ConfigSecuritySensitive + public ExchangeS3Config setS3AwsSecretKey(String s3AwsSecretKey) + { + this.s3AwsSecretKey = s3AwsSecretKey; + return this; + } + + public String getS3Region() + { + return s3Region; + } + + @Config("exchange.s3.region") + public ExchangeS3Config setS3Region(String s3Region) + { + this.s3Region = s3Region; + return this; + } + + public String getS3Endpoint() + { + return s3Endpoint; + } + + @Config("exchange.s3.endpoint") + public ExchangeS3Config setS3Endpoint(String s3Endpoint) + { + this.s3Endpoint = s3Endpoint; + return this; + } + + @Min(0) + public int getS3MaxErrorRetries() + { + return s3MaxErrorRetries; + } + + @Config("exchange.s3.max-error-retries") + public ExchangeS3Config setS3MaxErrorRetries(int s3MaxErrorRetries) + { + this.s3MaxErrorRetries = 
s3MaxErrorRetries; + return this; + } + + @Min(1) + public int getDeletionThreadCount() + { + return deletionThreadCount; + } + + @Config("exchange.s3.deletion.thread-count") + public ExchangeS3Config setDeletionThreadCount(int deletionThreadCount) + { + this.deletionThreadCount = deletionThreadCount; + return this; + } + + @NotNull + @MinDataSize("5MB") + @MaxDataSize("256MB") + public DataSize getS3UploadPartSize() + { + return s3UploadPartSize; + } + + @Config("exchange.s3.upload.part-size") + @ConfigDescription("Part size for S3 multi-part upload") + public ExchangeS3Config setS3UploadPartSize(DataSize s3UploadPartSize) + { + this.s3UploadPartSize = s3UploadPartSize; + return this; + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/s3/S3FileSystemExchangeStorage.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/s3/S3FileSystemExchangeStorage.java new file mode 100644 index 000000000000..9b97f586eab9 --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/s3/S3FileSystemExchangeStorage.java @@ -0,0 +1,573 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.exchange.s3; + +import com.google.common.collect.ImmutableList; +import com.google.common.io.Closer; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.airlift.slice.InputStreamSliceInput; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceInput; +import io.trino.plugin.exchange.ExchangeStorageWriter; +import io.trino.plugin.exchange.FileSystemExchangeStorage; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.internal.util.Mimetype; +import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3AsyncClientBuilder; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3ClientBuilder; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.Delete; +import 
software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.DeleteObjectsRequest; +import software.amazon.awssdk.services.s3.model.ExpirationStatus; +import software.amazon.awssdk.services.s3.model.GetBucketLifecycleConfigurationRequest; +import software.amazon.awssdk.services.s3.model.GetBucketLifecycleConfigurationResponse; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectResponse; +import software.amazon.awssdk.services.s3.model.LifecycleRuleFilter; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Request; +import software.amazon.awssdk.services.s3.model.NoSuchKeyException; +import software.amazon.awssdk.services.s3.model.ObjectIdentifier; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.S3Object; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable; + +import javax.annotation.PreDestroy; +import javax.crypto.SecretKey; +import javax.inject.Inject; + +import java.io.ByteArrayInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.stream.Stream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.nullToEmpty; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static 
com.google.common.util.concurrent.Futures.immediateVoidFuture; +import static io.airlift.concurrent.MoreFutures.asVoid; +import static io.airlift.concurrent.MoreFutures.toListenableFuture; +import static io.trino.plugin.exchange.FileSystemExchangeManager.PATH_SEPARATOR; +import static java.lang.Math.toIntExact; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; +import static software.amazon.awssdk.core.client.config.SdkAdvancedClientOption.USER_AGENT_PREFIX; +import static software.amazon.awssdk.core.client.config.SdkAdvancedClientOption.USER_AGENT_SUFFIX; +import static software.amazon.awssdk.core.sync.RequestBody.fromContentProvider; + +public class S3FileSystemExchangeStorage + implements FileSystemExchangeStorage +{ + private static final String DIRECTORY_SUFFIX = "_$folder$"; + + private final Region region; + private final String endpoint; + private final int multiUploadPartSize; + private final S3Client s3Client; + private final S3AsyncClient s3AsyncClient; + + private final ExecutorService deleteExecutor; + + @Inject + public S3FileSystemExchangeStorage(ExchangeS3Config config) + { + if (config.getS3Region() != null) { + this.region = Region.of(config.getS3Region().toLowerCase(ENGLISH)); + } + else { + this.region = null; + } + this.endpoint = config.getS3Endpoint(); + this.multiUploadPartSize = toIntExact(config.getS3UploadPartSize().toBytes()); + + AwsCredentialsProvider credentialsProvider = createAwsCredentialsProvider(config); + RetryPolicy retryPolicy = RetryPolicy.builder() + .numRetries(config.getS3MaxErrorRetries()) + .build(); + ClientOverrideConfiguration overrideConfig = ClientOverrideConfiguration.builder() + .retryPolicy(retryPolicy) + .putAdvancedOption(USER_AGENT_PREFIX, "") + .putAdvancedOption(USER_AGENT_SUFFIX, "Trino-exchange") + .build(); + + this.s3Client = createS3Client(credentialsProvider, overrideConfig); + this.s3AsyncClient = createS3AsyncClient(credentialsProvider, overrideConfig); + 
this.deleteExecutor = Executors.newFixedThreadPool( + config.getDeletionThreadCount(), + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("exchange-s3-deletion-%d") + .build()); + } + + @Override + public void initialize(URI baseDirectory) + { + // TODO: decide if we want to check for expiration life cycle rules + String bucketName = getBucketName(baseDirectory); + GetBucketLifecycleConfigurationRequest request = GetBucketLifecycleConfigurationRequest.builder() + .bucket(bucketName) + .build(); + GetBucketLifecycleConfigurationResponse response = s3Client.getBucketLifecycleConfiguration(request); + + verify(response.rules().stream().anyMatch( + rule -> rule.expiration() != null && + rule.abortIncompleteMultipartUpload() != null && + rule.status().equals(ExpirationStatus.ENABLED) && + rule.filter().equals(LifecycleRuleFilter.builder().build()) + ), "Expected file expiration and abortIncompleteMultipartUpload lifecycle rule for exchange bucket %s", baseDirectory.toString()); + } + + @Override + public void createDirectories(URI dir) + throws IOException + { + // no need to do anything for S3 + } + + @Override + public SliceInput getSliceInput(URI file, Optional secretKey) + throws IOException + { + GetObjectRequest.Builder getObjectRequestBuilder = GetObjectRequest.builder() + .bucket(getBucketName(file)) + .key(keyFromUri(file)); + S3RequestUtil.configureEncryption(secretKey, getObjectRequestBuilder); + + try { + return new InputStreamSliceInput(s3Client.getObject(getObjectRequestBuilder.build(), ResponseTransformer.toInputStream())); + } + catch (AwsServiceException e) { + throw new IOException(e); + } + } + + @Override + public ExchangeStorageWriter createExchangeStorageWriter(URI file, Optional secretKey) + { + String bucketName = getBucketName(file); + String key = keyFromUri(file); + + return new S3ExchangeStorageWriter(s3Client, s3AsyncClient, bucketName, key, multiUploadPartSize, secretKey); + } + + @Override + public boolean exists(URI file) + 
throws IOException + { + // Only used for commit marker files and doesn't need secretKey + return headObject(file, Optional.empty()) != null; + } + + @Override + public void createEmptyFile(URI file) + throws IOException + { + PutObjectRequest request = PutObjectRequest.builder() + .bucket(getBucketName(file)) + .key(keyFromUri(file)) + .build(); + + try { + s3Client.putObject(request, RequestBody.empty()); + } + catch (AwsServiceException e) { + throw new IOException(e); + } + } + + @Override + public void deleteRecursively(URI uri) + { + deleteExecutor.submit(() -> { + if (isDirectory(uri)) { + ImmutableList.Builder keys = ImmutableList.builder(); + for (S3Object s3Object : listObjectsRecursively(uri).contents()) { + keys.add(s3Object.key()); + } + keys.add(keyFromUri(uri) + DIRECTORY_SUFFIX); + + deleteObjects(getBucketName(uri), keys.build()); + } + else { + deleteObject(getBucketName(uri), keyFromUri(uri)); + } + }); + } + + @Override + public Stream listFiles(URI dir) + { + return listObjects(dir).contents().stream().filter(object -> !object.key().endsWith(PATH_SEPARATOR)).map(object -> { + try { + return new URI(dir.getScheme(), dir.getHost(), PATH_SEPARATOR + object.key(), dir.getFragment()); + } + catch (URISyntaxException e) { + throw new IllegalArgumentException(e); + } + }); + } + + @Override + public Stream listDirectories(URI dir) + { + return listObjects(dir).commonPrefixes().stream().map(prefix -> { + try { + return new URI(dir.getScheme(), dir.getHost(), PATH_SEPARATOR + prefix.prefix(), dir.getFragment()); + } + catch (URISyntaxException e) { + throw new IllegalArgumentException(e); + } + }); + } + + @Override + public long size(URI uri, Optional secretKey) + throws IOException + { + checkArgument(!isDirectory(uri), "expected a file URI but got a directory URI"); + HeadObjectResponse response = headObject(uri, secretKey); + if (response == null) { + throw new FileNotFoundException("File does not exist: " + uri); + } + return 
response.contentLength(); + } + + @Override + public int getWriteBufferSizeInBytes() + { + return multiUploadPartSize; + } + + @PreDestroy + @Override + public void close() + throws IOException + { + try (Closer closer = Closer.create()) { + closer.register(deleteExecutor::shutdown); + closer.register(s3Client::close); + closer.register(s3AsyncClient::close); + } + } + + private HeadObjectResponse headObject(URI uri, Optional secretKey) + throws IOException + { + HeadObjectRequest.Builder headObjectRequestBuilder = HeadObjectRequest.builder() + .bucket(getBucketName(uri)) + .key(keyFromUri(uri)); + S3RequestUtil.configureEncryption(secretKey, headObjectRequestBuilder); + + try { + return s3Client.headObject(headObjectRequestBuilder.build()); + } + catch (AwsServiceException e) { + if (e instanceof NoSuchKeyException) { + return null; + } + throw new IOException(e); + } + } + + private ListObjectsV2Iterable listObjects(URI dir) + { + String key = keyFromUri(dir); + if (!key.isEmpty()) { + key += PATH_SEPARATOR; + } + + ListObjectsV2Request request = ListObjectsV2Request.builder() + .bucket(getBucketName(dir)) + .prefix(key) + .delimiter(PATH_SEPARATOR) + .build(); + + return s3Client.listObjectsV2Paginator(request); + } + + private ListObjectsV2Iterable listObjectsRecursively(URI dir) + { + ListObjectsV2Request request = ListObjectsV2Request.builder() + .bucket(getBucketName(dir)) + .prefix(keyFromUri(dir)) + .build(); + + return s3Client.listObjectsV2Paginator(request); + } + + private void deleteObject(String bucketName, String key) + { + DeleteObjectRequest request = DeleteObjectRequest.builder() + .bucket(bucketName) + .key(key) + .build(); + s3Client.deleteObject(request); + } + + private void deleteObjects(String bucketName, List keys) + { + DeleteObjectsRequest request = DeleteObjectsRequest.builder() + .bucket(bucketName) + .delete(Delete.builder().objects(keys.stream().map(s -> ObjectIdentifier.builder().key(s).build()).collect(toImmutableList())).build()) 
+ .build(); + s3Client.deleteObjects(request); + } + + /** + * Helper function that works around java.net.URI's handling of S3 bucket names containing '_': for such + * bucket names uri.getHost() returns null, whereas bucket names without '_' have a properly + * set host field. '_' is only allowed in S3 bucket names in us-east-1. + * + * @param uri The URI from which to extract a host value. + * @return The host value where uri.getAuthority() is used when uri.getHost() returns null as long as no UserInfo is present. + * @throws IllegalArgumentException If the bucket cannot be determined from the URI. + */ + private static String getBucketName(URI uri) + { + if (uri.getHost() != null) { + return uri.getHost(); + } + + if (uri.getUserInfo() == null) { + return uri.getAuthority(); + } + + throw new IllegalArgumentException("Unable to determine S3 bucket from URI."); + } + + private static String keyFromUri(URI uri) + { + checkArgument(uri.isAbsolute(), "Uri is not absolute: %s", uri); + String key = nullToEmpty(uri.getPath()); + if (key.startsWith(PATH_SEPARATOR)) { + key = key.substring(PATH_SEPARATOR.length()); + } + if (key.endsWith(PATH_SEPARATOR)) { + key = key.substring(0, key.length() - PATH_SEPARATOR.length()); + } + return key; + } + + private static boolean isDirectory(URI uri) + { + return uri.toString().endsWith(PATH_SEPARATOR); + } + + private static AwsCredentialsProvider createAwsCredentialsProvider(ExchangeS3Config config) + { + if (config.getS3AwsAccessKey() != null && config.getS3AwsSecretKey() != null) { + return StaticCredentialsProvider.create(AwsBasicCredentials.create(config.getS3AwsAccessKey(), config.getS3AwsSecretKey())); + } + return DefaultCredentialsProvider.create(); + } + + private S3Client createS3Client(AwsCredentialsProvider credentialsProvider, ClientOverrideConfiguration overrideConfig) + { + S3ClientBuilder clientBuilder = S3Client.builder() + .credentialsProvider(credentialsProvider) + 
.overrideConfiguration(overrideConfig); + + if (region != null) { + clientBuilder = clientBuilder.region(region); + } + if (endpoint != null) { + clientBuilder = clientBuilder.endpointOverride(URI.create(endpoint)); + } + + return clientBuilder.build(); + } + + private S3AsyncClient createS3AsyncClient(AwsCredentialsProvider credentialsProvider, ClientOverrideConfiguration overrideConfig) + { + S3AsyncClientBuilder clientBuilder = S3AsyncClient.builder() + .credentialsProvider(credentialsProvider) + .overrideConfiguration(overrideConfig); + + if (region != null) { + clientBuilder = clientBuilder.region(region); + } + if (endpoint != null) { + clientBuilder = clientBuilder.endpointOverride(URI.create(endpoint)); + } + + return clientBuilder.build(); + } + + private static class S3ExchangeStorageWriter + implements ExchangeStorageWriter + { + private final S3Client s3Client; + private final S3AsyncClient s3AsyncClient; + private final String bucketName; + private final String key; + private final int partSize; + private final Optional secretKey; + + private int currentPartNumber; + private Optional uploadId = Optional.empty(); + private final List> uploadFutures = new ArrayList<>(); + + public S3ExchangeStorageWriter(S3Client s3Client, S3AsyncClient s3AsyncClient, String bucketName, String key, int partSize, Optional secretKey) + { + this.s3Client = requireNonNull(s3Client, "s3Client is null"); + this.s3AsyncClient = requireNonNull(s3AsyncClient, "s3AsyncClient is null"); + this.bucketName = requireNonNull(bucketName, "bucketName is null"); + this.key = requireNonNull(key, "key is null"); + this.partSize = partSize; + this.secretKey = requireNonNull(secretKey, "secretKey is null"); + } + + @Override + public ListenableFuture write(Slice slice) + throws IOException + { + // skip multipart upload if there would only be one part + if (slice.length() < partSize && uploadId.isEmpty()) { + PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder() + 
.bucket(bucketName) + .key(key); + S3RequestUtil.configureEncryption(secretKey, putObjectRequestBuilder); + + try { + s3Client.putObject(putObjectRequestBuilder.build(), + // avoid extra memory copy + fromContentProvider(() -> new ByteArrayInputStream(slice.getBytes()), slice.length(), Mimetype.MIMETYPE_OCTET_STREAM)); + return immediateVoidFuture(); + } + catch (AwsServiceException e) { + throw new IOException(e); + } + } + + if (uploadId.isEmpty()) { + uploadId = Optional.of(createMultipartUpload().uploadId()); + } + + CompletableFuture uploadFuture = uploadPart(uploadId.get(), slice); + uploadFutures.add(uploadFuture); + + return asVoid(toListenableFuture(uploadFuture)); + } + + @Override + public void close() + throws IOException + { + if (uploadId.isEmpty()) { + return; + } + + try { + List completedParts = uploadFutures.stream() + .map(CompletableFuture::join) + .sorted(Comparator.comparing(CompletedPart::partNumber)) + .collect(toImmutableList()); + completeMultiUpload(uploadId.get(), completedParts); + } + catch (RuntimeException e) { + abortUploadSuppressed(uploadId.get(), e); + throw new IOException(e); + } + } + + private CreateMultipartUploadResponse createMultipartUpload() + { + CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder() + .bucket(bucketName) + .key(key); + S3RequestUtil.configureEncryption(secretKey, createMultipartUploadRequestBuilder); + return s3Client.createMultipartUpload(createMultipartUploadRequestBuilder.build()); + } + + private CompletableFuture uploadPart(String uploadId, Slice slice) + { + currentPartNumber++; + UploadPartRequest.Builder uploadPartRequestBuilder = UploadPartRequest.builder() + .bucket(bucketName) + .key(key) + .uploadId(uploadId) + .partNumber(currentPartNumber); + S3RequestUtil.configureEncryption(secretKey, uploadPartRequestBuilder); + UploadPartRequest uploadPartRequest = uploadPartRequestBuilder.build(); + return 
s3AsyncClient.uploadPart(uploadPartRequest, DirectByteArrayAsyncRequestBody.fromByteBuffer(slice.toByteBuffer())) + .thenApply(uploadPartResponse -> CompletedPart.builder().eTag(uploadPartResponse.eTag()).partNumber(uploadPartRequest.partNumber()).build()); + } + + private void completeMultiUpload(String uploadId, List completedParts) + { + CompletedMultipartUpload completedMultipartUpload = CompletedMultipartUpload.builder() + .parts(completedParts) + .build(); + CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder() + .bucket(bucketName) + .key(key) + .uploadId(uploadId) + .multipartUpload(completedMultipartUpload) + .build(); + s3Client.completeMultipartUpload(completeMultipartUploadRequest); + } + + private void abortUpload(String uploadId) + { + AbortMultipartUploadRequest abortMultipartUploadRequest = AbortMultipartUploadRequest.builder() + .bucket(bucketName) + .key(key) + .uploadId(uploadId) + .build(); + s3Client.abortMultipartUpload(abortMultipartUploadRequest); + } + + @SuppressWarnings("ObjectEquality") + private void abortUploadSuppressed(String uploadId, Throwable throwable) + { + try { + abortUpload(uploadId); + } + catch (Throwable t) { + if (throwable != t) { + throwable.addSuppressed(t); + } + } + } + } +} diff --git a/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/s3/S3RequestUtil.java b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/s3/S3RequestUtil.java new file mode 100644 index 000000000000..688326bc48ad --- /dev/null +++ b/plugin/trino-exchange/src/main/java/io/trino/plugin/exchange/s3/S3RequestUtil.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.exchange.s3; + +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.ServerSideEncryption; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.utils.Md5Utils; + +import javax.crypto.SecretKey; + +import java.util.Base64; +import java.util.Optional; +import java.util.function.Consumer; + +public final class S3RequestUtil +{ + private S3RequestUtil() + { + } + + static void configureEncryption(Optional secretKey, PutObjectRequest.Builder requestBuilder) + { + configureEncryption(secretKey, requestBuilder::sseCustomerAlgorithm, requestBuilder::sseCustomerKey, requestBuilder::sseCustomerKeyMD5); + } + + static void configureEncryption(Optional secretKey, CreateMultipartUploadRequest.Builder requestBuilder) + { + configureEncryption(secretKey, requestBuilder::sseCustomerAlgorithm, requestBuilder::sseCustomerKey, requestBuilder::sseCustomerKeyMD5); + } + + static void configureEncryption(Optional secretKey, UploadPartRequest.Builder requestBuilder) + { + configureEncryption(secretKey, requestBuilder::sseCustomerAlgorithm, requestBuilder::sseCustomerKey, requestBuilder::sseCustomerKeyMD5); + } + + static void configureEncryption(Optional secretKey, GetObjectRequest.Builder requestBuilder) + { + 
configureEncryption(secretKey, requestBuilder::sseCustomerAlgorithm, requestBuilder::sseCustomerKey, requestBuilder::sseCustomerKeyMD5); + } + + static void configureEncryption(Optional secretKey, HeadObjectRequest.Builder requestBuilder) + { + configureEncryption(secretKey, requestBuilder::sseCustomerAlgorithm, requestBuilder::sseCustomerKey, requestBuilder::sseCustomerKeyMD5); + } + + static void configureEncryption( + Optional secretKey, + Consumer customAlgorithmSetter, + Consumer customKeySetter, + Consumer customMd5Setter) + { + secretKey.ifPresent(key -> { + customAlgorithmSetter.accept(ServerSideEncryption.AES256.name()); + customKeySetter.accept(Base64.getEncoder().encodeToString(key.getEncoded())); + customMd5Setter.accept(Md5Utils.md5AsBase64(key.getEncoded())); + }); + } +} diff --git a/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/TestFileSystemExchangeConfig.java b/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/TestFileSystemExchangeConfig.java new file mode 100644 index 000000000000..f5a408c03713 --- /dev/null +++ b/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/TestFileSystemExchangeConfig.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.exchange; + +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; + +public class TestFileSystemExchangeConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(FileSystemExchangeConfig.class) + .setBaseDirectory(null) + .setExchangeEncryptionEnabled(false)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("exchange.base-directory", "s3n://exchange-spooling-test/") + .put("exchange.encryption-enabled", "true") + .build(); + + FileSystemExchangeConfig expected = new FileSystemExchangeConfig() + .setBaseDirectory("s3n://exchange-spooling-test/") + .setExchangeEncryptionEnabled(true); + + assertFullMapping(properties, expected); + } +} diff --git a/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/containers/MinioStorage.java b/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/containers/MinioStorage.java new file mode 100644 index 000000000000..5b411c00755e --- /dev/null +++ b/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/containers/MinioStorage.java @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.exchange.containers; + +import com.google.common.collect.ImmutableMap; +import io.trino.testing.containers.Minio; +import io.trino.util.AutoCloseableCloser; +import org.testcontainers.containers.Network; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.CreateBucketRequest; + +import java.net.URI; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; +import static org.testcontainers.containers.Network.newNetwork; +import static software.amazon.awssdk.regions.Region.US_EAST_1; + +public class MinioStorage + implements AutoCloseable +{ + public static final String ACCESS_KEY = "accesskey"; + public static final String SECRET_KEY = "secretkey"; + + private final String bucketName; + private final Minio minio; + + private final AutoCloseableCloser closer = AutoCloseableCloser.create(); + + private State state = State.INITIAL; + + public MinioStorage(String bucketName) + { + this.bucketName = requireNonNull(bucketName, "bucketName is null"); + Network network = closer.register(newNetwork()); + this.minio = closer.register( + Minio.builder() + .withNetwork(network) + .withEnvVars(ImmutableMap.builder() + .put("MINIO_ACCESS_KEY", ACCESS_KEY) + .put("MINIO_SECRET_KEY", SECRET_KEY) + .build()) + .build()); + } + + public void start() + { + checkState(state == State.INITIAL, "Already started: %s", state); + state = State.STARTING; + minio.start(); + S3Client s3Client = S3Client.builder() + .endpointOverride(URI.create("http://localhost:" + minio.getMinioApiEndpoint().getPort())) + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create(ACCESS_KEY, SECRET_KEY))) + 
.region(US_EAST_1) + .build(); + CreateBucketRequest createBucketRequest = CreateBucketRequest.builder() + .bucket(bucketName) + .build(); + s3Client.createBucket(createBucketRequest); + state = State.STARTED; + } + + public Minio getMinio() + { + return minio; + } + + public String getBucketName() + { + return bucketName; + } + + @Override + public void close() + throws Exception + { + closer.close(); + state = State.STOPPED; + } + + private enum State + { + INITIAL, + STARTING, + STARTED, + STOPPED, + } +} diff --git a/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/local/TestLocalFileSystemExchangeManager.java b/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/local/TestLocalFileSystemExchangeManager.java new file mode 100644 index 000000000000..e37dc70fa1be --- /dev/null +++ b/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/local/TestLocalFileSystemExchangeManager.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.exchange.local; + +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.exchange.FileSystemExchangeManagerFactory; +import io.trino.spi.exchange.ExchangeManager; +import io.trino.testing.AbstractTestExchangeManager; + +public class TestLocalFileSystemExchangeManager + extends AbstractTestExchangeManager +{ + @Override + protected ExchangeManager createExchangeManager() + { + return new FileSystemExchangeManagerFactory().create(ImmutableMap.of( + "exchange.base-directory", System.getProperty("java.io.tmpdir") + "/trino-local-file-system-exchange-manager", + "exchange.encryption-enabled", "true")); + } +} diff --git a/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/s3/TestExchangeS3Config.java b/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/s3/TestExchangeS3Config.java new file mode 100644 index 000000000000..26b2b3f87b69 --- /dev/null +++ b/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/s3/TestExchangeS3Config.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.exchange.s3; + +import com.google.common.collect.ImmutableMap; +import io.airlift.units.DataSize; +import org.testng.annotations.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static java.lang.Math.max; + +public class TestExchangeS3Config +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(ExchangeS3Config.class) + .setS3AwsAccessKey(null) + .setS3AwsSecretKey(null) + .setS3Region(null) + .setS3Endpoint(null) + .setS3MaxErrorRetries(10) + .setDeletionThreadCount(max(1, Runtime.getRuntime().availableProcessors() / 2)) + .setS3UploadPartSize(DataSize.of(5, MEGABYTE))); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("exchange.s3.aws-access-key", "access") + .put("exchange.s3.aws-secret-key", "secret") + .put("exchange.s3.region", "us-west-1") + .put("exchange.s3.endpoint", "https://s3.us-east-1.amazonaws.com") + .put("exchange.s3.max-error-retries", "8") + .put("exchange.s3.deletion.thread-count", "3") + .put("exchange.s3.upload.part-size", "10MB") + .build(); + + ExchangeS3Config expected = new ExchangeS3Config() + .setS3AwsAccessKey("access") + .setS3AwsSecretKey("secret") + .setS3Region("us-west-1") + .setS3Endpoint("https://s3.us-east-1.amazonaws.com") + .setS3MaxErrorRetries(8) + .setDeletionThreadCount(3) + .setS3UploadPartSize(DataSize.of(10, MEGABYTE)); + + assertFullMapping(properties, expected); + } +} diff --git a/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/s3/TestS3FileSystemExchangeManager.java b/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/s3/TestS3FileSystemExchangeManager.java new file mode 
100644 index 000000000000..c2441ce87b24 --- /dev/null +++ b/plugin/trino-exchange/src/test/java/io/trino/plugin/exchange/s3/TestS3FileSystemExchangeManager.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.exchange.s3; + +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.exchange.FileSystemExchangeManagerFactory; +import io.trino.plugin.exchange.containers.MinioStorage; +import io.trino.spi.exchange.ExchangeManager; +import io.trino.testing.AbstractTestExchangeManager; +import org.testng.annotations.AfterClass; + +import java.util.Map; + +import static io.trino.testing.sql.TestTable.randomTableSuffix; + +public class TestS3FileSystemExchangeManager + extends AbstractTestExchangeManager +{ + private MinioStorage minioStorage; + + @Override + protected ExchangeManager createExchangeManager() + { + this.minioStorage = new MinioStorage("test-exchange-spooling-" + randomTableSuffix()); + minioStorage.start(); + + Map exchangeManagerProperties = new ImmutableMap.Builder() + .put("exchange.base-directory", "s3n://" + minioStorage.getBucketName()) + .put("exchange.s3.aws-access-key", MinioStorage.ACCESS_KEY) + .put("exchange.s3.aws-secret-key", MinioStorage.SECRET_KEY) + .put("exchange.s3.region", "us-east-1") + // TODO: enable exchange encryption after https is supported for Trino MinIO + .put("exchange.s3.endpoint", "http://" + minioStorage.getMinio().getMinioApiEndpoint()) + .build(); + 
return new FileSystemExchangeManagerFactory().create(exchangeManagerProperties); + } + + @AfterClass(alwaysRun = true) + public void destroy() + throws Exception + { + if (minioStorage != null) { + minioStorage.close(); + minioStorage = null; + } + } +} diff --git a/plugin/trino-hive/pom.xml b/plugin/trino-hive/pom.xml index 1623967f0086..79a4dca6ce48 100644 --- a/plugin/trino-hive/pom.xml +++ b/plugin/trino-hive/pom.xml @@ -309,6 +309,19 @@ test + + io.trino + trino-exchange + test + + + + io.trino + trino-exchange + test-jar + test + + io.trino trino-main @@ -411,10 +424,21 @@ **/TestHiveGlueMetastore.java **/TestFullParquetReader.java - **/TestHiveFailureRecovery.java + **/TestHive*FailureRecovery.java + **/TestHiveFaultTolerantExecution*.java + + org.basepom.maven + duplicate-finder-maven-plugin + + + mime.types + about.html + + + @@ -462,7 +486,8 @@ maven-surefire-plugin - **/TestHiveFailureRecovery.java + **/TestHive*FailureRecovery.java + **/TestHiveFaultTolerantExecution*.java diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractHiveConnectorTest.java similarity index 99% rename from plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java rename to plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractHiveConnectorTest.java index 94ec14a09414..1db7d10183be 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractHiveConnectorTest.java @@ -171,24 +171,30 @@ import static org.testng.Assert.fail; import static org.testng.FileAssert.assertFile; -public class TestHiveConnectorTest +public abstract class AbstractHiveConnectorTest extends BaseConnectorTest { private static final DateTimeFormatter TIMESTAMP_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSSSSSSSS"); private final String catalog; private final Session 
bucketedSession; + private final Map extraProperties; + private final Map exchangeManagerProperties; - public TestHiveConnectorTest() + protected AbstractHiveConnectorTest(Map extraProperties, Map exchangeManagerProperties) { this.catalog = HIVE_CATALOG; this.bucketedSession = createBucketedSession(Optional.of(new SelectedRole(ROLE, Optional.of("admin")))); + this.extraProperties = ImmutableMap.copyOf(requireNonNull(extraProperties, "extraProperties is null")); + this.exchangeManagerProperties = ImmutableMap.copyOf(requireNonNull(exchangeManagerProperties, "exchangeManagerProperties is null")); } @Override - protected QueryRunner createQueryRunner() + protected final QueryRunner createQueryRunner() throws Exception { DistributedQueryRunner queryRunner = HiveQueryRunner.builder() + .setExtraProperties(extraProperties) + .setExchangeManagerProperties(exchangeManagerProperties) .setHiveProperties(ImmutableMap.of( "hive.allow-register-partition-procedure", "true", // Reduce writer sort buffer size to ensure SortingFileWriter gets used @@ -1747,6 +1753,11 @@ public void testCreateTableWithUnsupportedType() @Test public void testTargetMaxFileSize() + { + testTargetMaxFileSize(3); + } + + protected void testTargetMaxFileSize(int expectedTableWriters) { // We use TEXTFILE in this test because is has a very consistent and predictable size @Language("SQL") String createTableSql = "CREATE TABLE test_max_file_size WITH (format = 'TEXTFILE') AS SELECT * FROM tpch.sf1.lineitem LIMIT 1000000"; @@ -1757,7 +1768,7 @@ public void testTargetMaxFileSize() .setSystemProperty("task_writer_count", "1") .build(); assertUpdate(session, createTableSql, 1000000); - assertThat(computeActual(selectFileInfo).getRowCount()).isEqualTo(3); + assertThat(computeActual(selectFileInfo).getRowCount()).isEqualTo(expectedTableWriters); assertUpdate("DROP TABLE test_max_file_size"); // Write table with small limit and verify we get multiple files per node near the expected size @@ -1770,7 +1781,7 @@ 
public void testTargetMaxFileSize() assertUpdate(session, createTableSql, 1000000); MaterializedResult result = computeActual(selectFileInfo); - assertThat(result.getRowCount()).isGreaterThan(3); + assertThat(result.getRowCount()).isGreaterThan(expectedTableWriters); for (MaterializedRow row : result) { // allow up to a larger delta due to the very small max size and the relatively large writer chunk size assertThat((Long) row.getField(1)).isLessThan(maxSize.toBytes() * 3); @@ -1781,6 +1792,11 @@ public void testTargetMaxFileSize() @Test public void testTargetMaxFileSizePartitioned() + { + testTargetMaxFileSizePartitioned(3); + } + + protected void testTargetMaxFileSizePartitioned(int expectedTableWriters) { // We use TEXTFILE in this test because is has a very consistent and predictable size @Language("SQL") String createTableSql = "" + @@ -1794,7 +1810,7 @@ public void testTargetMaxFileSizePartitioned() .setSystemProperty("task_writer_count", "1") .build(); assertUpdate(session, createTableSql, 1000000); - assertThat(computeActual(selectFileInfo).getRowCount()).isEqualTo(9); + assertThat(computeActual(selectFileInfo).getRowCount()).isEqualTo(expectedTableWriters * 3); assertUpdate("DROP TABLE test_max_file_size"); // Write table with small limit and verify we get multiple files per node near the expected size @@ -1807,7 +1823,7 @@ public void testTargetMaxFileSizePartitioned() assertUpdate(session, createTableSql, 1000000); MaterializedResult result = computeActual(selectFileInfo); - assertThat(result.getRowCount()).isGreaterThan(9); + assertThat(result.getRowCount()).isGreaterThan(expectedTableWriters * 3); for (MaterializedRow row : result) { // allow up to a larger delta due to the very small max size and the relatively large writer chunk size assertThat((Long) row.getField(1)).isLessThan(maxSize.toBytes() * 3); @@ -3639,7 +3655,7 @@ public void testScaleWriters() testWithAllStorageFormats(this::testMultipleWriters); } - private void testSingleWriter(Session 
session, HiveStorageFormat storageFormat) + protected void testSingleWriter(Session session, HiveStorageFormat storageFormat) { try { // small table that will only have one writer @@ -8363,7 +8379,7 @@ private static class RollbackException { } - private void testWithAllStorageFormats(BiConsumer test) + protected void testWithAllStorageFormats(BiConsumer test) { for (TestingHiveStorageFormat storageFormat : getAllTestingHiveStorageFormat()) { testWithStorageFormat(storageFormat, test); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFailureRecovery.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveFailureRecovery.java similarity index 89% rename from plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFailureRecovery.java rename to plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveFailureRecovery.java index ce2165575140..84750be62d46 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFailureRecovery.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHiveFailureRecovery.java @@ -14,6 +14,7 @@ package io.trino.plugin.hive; import io.trino.Session; +import io.trino.operator.RetryPolicy; import io.trino.testing.AbstractTestFailureRecovery; import io.trino.testing.QueryRunner; import io.trino.tpch.TpchTable; @@ -25,17 +26,27 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; -public class TestHiveFailureRecovery +public abstract class AbstractTestHiveFailureRecovery extends AbstractTestFailureRecovery { + protected AbstractTestHiveFailureRecovery(RetryPolicy retryPolicy, Map exchangeManagerProperties) + { + super(retryPolicy, exchangeManagerProperties); + } + @Override - protected QueryRunner createQueryRunner(List> requiredTpchTables, Map configProperties, Map coordinatorProperties) + protected final QueryRunner createQueryRunner( + List> requiredTpchTables, + Map configProperties, + Map coordinatorProperties, + Map 
exchangeManagerProperties) throws Exception { return HiveQueryRunner.builder() .setInitialTables(requiredTpchTables) .setCoordinatorProperties(coordinatorProperties) .setExtraProperties(configProperties) + .setExchangeManagerProperties(exchangeManagerProperties) .build(); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveQueryRunner.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveQueryRunner.java index 648a8bfbf69f..7a3767123f1e 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveQueryRunner.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveQueryRunner.java @@ -20,6 +20,7 @@ import io.airlift.log.Logging; import io.trino.Session; import io.trino.metadata.QualifiedObjectName; +import io.trino.plugin.exchange.FileSystemExchangePlugin; import io.trino.plugin.hive.authentication.HiveIdentity; import io.trino.plugin.hive.metastore.Database; import io.trino.plugin.hive.metastore.HiveMetastore; @@ -90,6 +91,7 @@ public static class Builder> { private boolean skipTimezoneSetup; private ImmutableMap.Builder hiveProperties = ImmutableMap.builder(); + private Map exchangeManagerProperties = ImmutableMap.of(); private List> initialTables = ImmutableList.of(); private Optional initialSchemasLocationBase = Optional.empty(); private Function metastore = queryRunner -> { @@ -128,6 +130,12 @@ public SELF addHiveProperty(String key, String value) return self(); } + public SELF setExchangeManagerProperties(Map exchangeManagerProperties) + { + this.exchangeManagerProperties = ImmutableMap.copyOf(requireNonNull(exchangeManagerProperties, "exchangeManagerProperties is null")); + return self(); + } + public SELF setInitialTables(Iterable> initialTables) { this.initialTables = ImmutableList.copyOf(requireNonNull(initialTables, "initialTables is null")); @@ -167,6 +175,11 @@ public DistributedQueryRunner build() HiveMetastore metastore = this.metastore.apply(queryRunner); queryRunner.installPlugin(new 
TestingHivePlugin(metastore, module)); + if (!exchangeManagerProperties.isEmpty()) { + queryRunner.installPlugin(new FileSystemExchangePlugin()); + queryRunner.loadExchangeManager("filesystem", exchangeManagerProperties); + } + Map hiveProperties = new HashMap<>(); if (!skipTimezoneSetup) { assertEquals(DateTimeZone.getDefault(), TIME_ZONE, "Timezone not configured correctly. Add -Duser.timezone=America/Bahia_Banderas to your JVM arguments"); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionAggregationsFile.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionAggregationsFile.java new file mode 100644 index 000000000000..ac47c38e1d36 --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionAggregationsFile.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive; + +import io.trino.testing.AbstractTestFaultTolerantExecutionAggregations; +import io.trino.testing.BaseFaultTolerantExecutionConnectorTest; +import io.trino.testing.QueryRunner; + +import java.util.Map; + +import static io.trino.tpch.TpchTable.getTables; + +public class TestHiveFaultTolerantExecutionAggregationsFile + extends AbstractTestFaultTolerantExecutionAggregations +{ + @Override + protected QueryRunner createQueryRunner(Map extraProperties) + throws Exception + { + return HiveQueryRunner.builder() + .setExtraProperties(extraProperties) + .setExchangeManagerProperties(BaseFaultTolerantExecutionConnectorTest.getExchangeManagerPropertiesFile()) + .setInitialTables(getTables()) + .build(); + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionAggregationsMinio.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionAggregationsMinio.java new file mode 100644 index 000000000000..623643dc0ba8 --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionAggregationsMinio.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive; + +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.exchange.containers.MinioStorage; +import io.trino.testing.AbstractTestFaultTolerantExecutionAggregations; +import io.trino.testing.QueryRunner; +import org.testng.annotations.AfterClass; + +import java.util.Map; + +import static io.trino.testing.sql.TestTable.randomTableSuffix; +import static io.trino.tpch.TpchTable.getTables; + +public class TestHiveFaultTolerantExecutionAggregationsMinio + extends AbstractTestFaultTolerantExecutionAggregations +{ + private MinioStorage minioStorage; + + @Override + protected QueryRunner createQueryRunner(Map extraProperties) + throws Exception + { + this.minioStorage = new MinioStorage("test-exchange-spooling-" + randomTableSuffix()); + minioStorage.start(); + + Map exchangeManagerProperties = new ImmutableMap.Builder() + .put("exchange.base-directory", "s3n://" + minioStorage.getBucketName()) + .put("exchange.s3.aws-access-key", MinioStorage.ACCESS_KEY) + .put("exchange.s3.aws-secret-key", MinioStorage.SECRET_KEY) + .put("exchange.s3.region", "us-east-1") + // TODO: enable exchange encryption after https is supported for Trino MinIO + .put("exchange.s3.endpoint", "http://" + minioStorage.getMinio().getMinioApiEndpoint()) + .build(); + + return HiveQueryRunner.builder() + .setExtraProperties(extraProperties) + .setExchangeManagerProperties(exchangeManagerProperties) + .setInitialTables(getTables()) + .build(); + } + + @AfterClass(alwaysRun = true) + public void destroy() + throws Exception + { + if (minioStorage != null) { + minioStorage.close(); + minioStorage = null; + } + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionConnectorTest.java new file mode 100644 index 000000000000..d49f6454dfdd --- /dev/null +++ 
b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionConnectorTest.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import io.trino.testing.BaseFaultTolerantExecutionConnectorTest; + +public class TestHiveFaultTolerantExecutionConnectorTest + extends AbstractHiveConnectorTest +{ + public TestHiveFaultTolerantExecutionConnectorTest() + { + super(BaseFaultTolerantExecutionConnectorTest.getExtraProperties(), BaseFaultTolerantExecutionConnectorTest.getExchangeManagerPropertiesFile()); + } + + @Override + public void testGroupedExecution() + { + // grouped execution is not supported (and not needed) with batch execution enabled + } + + @Override + public void testScaleWriters() + { + testWithAllStorageFormats(this::testSingleWriter); + } + + @Override + public void testTargetMaxFileSize() + { + testTargetMaxFileSize(1); + } + + @Override + public void testTargetMaxFileSizePartitioned() + { + testTargetMaxFileSizePartitioned(1); + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionJoinQueriesFile.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionJoinQueriesFile.java new file mode 100644 index 000000000000..1436d444e92a --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionJoinQueriesFile.java @@ -0,0 +1,45 @@ +/* + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import io.trino.testing.AbstractTestFaultTolerantExecutionJoinQueries; +import io.trino.testing.BaseFaultTolerantExecutionConnectorTest; +import io.trino.testing.QueryRunner; +import org.testng.annotations.Test; + +import java.util.Map; + +import static io.trino.tpch.TpchTable.getTables; + +public class TestHiveFaultTolerantExecutionJoinQueriesFile + extends AbstractTestFaultTolerantExecutionJoinQueries +{ + @Override + protected QueryRunner createQueryRunner(Map extraProperties) + throws Exception + { + return HiveQueryRunner.builder() + .setExtraProperties(extraProperties) + .setExchangeManagerProperties(BaseFaultTolerantExecutionConnectorTest.getExchangeManagerPropertiesFile()) + .setInitialTables(getTables()) + .build(); + } + + @Override + @Test(enabled = false) + public void testOutputDuplicatesInsensitiveJoin() + { + // flaky + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionJoinQueriesMinio.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionJoinQueriesMinio.java new file mode 100644 index 000000000000..facd1bb9c6ca --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionJoinQueriesMinio.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.exchange.containers.MinioStorage; +import io.trino.testing.AbstractTestFaultTolerantExecutionJoinQueries; +import io.trino.testing.QueryRunner; +import org.testng.annotations.AfterClass; + +import java.util.Map; + +import static io.trino.testing.sql.TestTable.randomTableSuffix; +import static io.trino.tpch.TpchTable.getTables; + +public class TestHiveFaultTolerantExecutionJoinQueriesMinio + extends AbstractTestFaultTolerantExecutionJoinQueries +{ + private MinioStorage minioStorage; + + @Override + protected QueryRunner createQueryRunner(Map extraProperties) + throws Exception + { + this.minioStorage = new MinioStorage("test-exchange-spooling-" + randomTableSuffix()); + minioStorage.start(); + + Map exchangeManagerProperties = new ImmutableMap.Builder() + .put("exchange.base-directory", "s3n://" + minioStorage.getBucketName()) + .put("exchange.s3.aws-access-key", MinioStorage.ACCESS_KEY) + .put("exchange.s3.aws-secret-key", MinioStorage.SECRET_KEY) + .put("exchange.s3.region", "us-east-1") + // TODO: enable exchange encryption after https is supported for Trino MinIO + .put("exchange.s3.endpoint", "http://" + minioStorage.getMinio().getMinioApiEndpoint()) + .build(); + + return HiveQueryRunner.builder() + .setExtraProperties(extraProperties) + .setExchangeManagerProperties(exchangeManagerProperties) + .setInitialTables(getTables()) + .build(); + } + + @AfterClass(alwaysRun = true) + public void destroy() + throws Exception + { + if 
(minioStorage != null) { + minioStorage.close(); + minioStorage = null; + } + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionOrderByQueriesFile.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionOrderByQueriesFile.java new file mode 100644 index 000000000000..17c2b8cc6726 --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionOrderByQueriesFile.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive; + +import io.trino.testing.AbstractTestFaultTolerantExecutionOrderByQueries; +import io.trino.testing.BaseFaultTolerantExecutionConnectorTest; +import io.trino.testing.QueryRunner; + +import java.util.Map; + +import static io.trino.tpch.TpchTable.getTables; + +public class TestHiveFaultTolerantExecutionOrderByQueriesFile + extends AbstractTestFaultTolerantExecutionOrderByQueries +{ + @Override + protected QueryRunner createQueryRunner(Map extraProperties) + throws Exception + { + return HiveQueryRunner.builder() + .setExtraProperties(extraProperties) + .setExchangeManagerProperties(BaseFaultTolerantExecutionConnectorTest.getExchangeManagerPropertiesFile()) + .setInitialTables(getTables()) + .build(); + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionOrderByQueriesMinio.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionOrderByQueriesMinio.java new file mode 100644 index 000000000000..f25daf13fa81 --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionOrderByQueriesMinio.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive; + +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.exchange.containers.MinioStorage; +import io.trino.testing.AbstractTestFaultTolerantExecutionOrderByQueries; +import io.trino.testing.QueryRunner; +import org.testng.annotations.AfterClass; + +import java.util.Map; + +import static io.trino.testing.sql.TestTable.randomTableSuffix; +import static io.trino.tpch.TpchTable.getTables; + +public class TestHiveFaultTolerantExecutionOrderByQueriesMinio + extends AbstractTestFaultTolerantExecutionOrderByQueries +{ + private MinioStorage minioStorage; + + @Override + protected QueryRunner createQueryRunner(Map extraProperties) + throws Exception + { + this.minioStorage = new MinioStorage("test-exchange-spooling-" + randomTableSuffix()); + minioStorage.start(); + + Map exchangeManagerProperties = new ImmutableMap.Builder() + .put("exchange.base-directory", "s3n://" + minioStorage.getBucketName()) + .put("exchange.s3.aws-access-key", MinioStorage.ACCESS_KEY) + .put("exchange.s3.aws-secret-key", MinioStorage.SECRET_KEY) + .put("exchange.s3.region", "us-east-1") + // TODO: enable exchange encryption after https is supported for Trino MinIO + .put("exchange.s3.endpoint", "http://" + minioStorage.getMinio().getMinioApiEndpoint()) + .build(); + + return HiveQueryRunner.builder() + .setExtraProperties(extraProperties) + .setExchangeManagerProperties(exchangeManagerProperties) + .setInitialTables(getTables()) + .build(); + } + + @AfterClass(alwaysRun = true) + public void destroy() + throws Exception + { + if (minioStorage != null) { + minioStorage.close(); + minioStorage = null; + } + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionWindowQueriesFile.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionWindowQueriesFile.java new file mode 100644 index 000000000000..0f2dea491741 --- /dev/null +++ 
b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionWindowQueriesFile.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import io.trino.testing.AbstractTestFaultTolerantExecutionWindowQueries; +import io.trino.testing.BaseFaultTolerantExecutionConnectorTest; +import io.trino.testing.QueryRunner; + +import java.util.Map; + +import static io.trino.tpch.TpchTable.getTables; + +public class TestHiveFaultTolerantExecutionWindowQueriesFile + extends AbstractTestFaultTolerantExecutionWindowQueries +{ + @Override + protected QueryRunner createQueryRunner(Map extraProperties) + throws Exception + { + return HiveQueryRunner.builder() + .setExtraProperties(extraProperties) + .setExchangeManagerProperties(BaseFaultTolerantExecutionConnectorTest.getExchangeManagerPropertiesFile()) + .setInitialTables(getTables()) + .build(); + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionWindowQueriesMinio.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionWindowQueriesMinio.java new file mode 100644 index 000000000000..64a82fcbf8a2 --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFaultTolerantExecutionWindowQueriesMinio.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.exchange.containers.MinioStorage; +import io.trino.testing.AbstractTestFaultTolerantExecutionWindowQueries; +import io.trino.testing.QueryRunner; +import org.testng.annotations.AfterClass; + +import java.util.Map; + +import static io.trino.testing.sql.TestTable.randomTableSuffix; +import static io.trino.tpch.TpchTable.getTables; + +public class TestHiveFaultTolerantExecutionWindowQueriesMinio + extends AbstractTestFaultTolerantExecutionWindowQueries +{ + private MinioStorage minioStorage; + + @Override + protected QueryRunner createQueryRunner(Map extraProperties) + throws Exception + { + this.minioStorage = new MinioStorage("test-exchange-spooling-" + randomTableSuffix()); + minioStorage.start(); + + Map exchangeManagerProperties = new ImmutableMap.Builder() + .put("exchange.base-directory", "s3n://" + minioStorage.getBucketName()) + .put("exchange.s3.aws-access-key", MinioStorage.ACCESS_KEY) + .put("exchange.s3.aws-secret-key", MinioStorage.SECRET_KEY) + .put("exchange.s3.region", "us-east-1") + // TODO: enable exchange encryption after https is supported for Trino MinIO + .put("exchange.s3.endpoint", "http://" + minioStorage.getMinio().getMinioApiEndpoint()) + .build(); + + return HiveQueryRunner.builder() + .setExtraProperties(extraProperties) + .setExchangeManagerProperties(exchangeManagerProperties) + .setInitialTables(getTables()) + .build(); + } + + @AfterClass(alwaysRun = true) + public void destroy() + throws Exception + 
{ + if (minioStorage != null) { + minioStorage.close(); + minioStorage = null; + } + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveQueryLevelFailureRecovery.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveQueryLevelFailureRecovery.java new file mode 100644 index 000000000000..f19535d19ede --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveQueryLevelFailureRecovery.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import com.google.common.collect.ImmutableMap; +import io.trino.operator.RetryPolicy; + +public class TestHiveQueryLevelFailureRecovery + extends AbstractTestHiveFailureRecovery +{ + protected TestHiveQueryLevelFailureRecovery() + { + super(RetryPolicy.QUERY, ImmutableMap.of()); + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveTaskLevelFailureRecovery.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveTaskLevelFailureRecovery.java new file mode 100644 index 000000000000..4bffdb804146 --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveTaskLevelFailureRecovery.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import com.google.common.collect.ImmutableMap; +import io.trino.operator.RetryPolicy; + +public class TestHiveTaskLevelFailureRecovery + extends AbstractTestHiveFailureRecovery +{ + protected TestHiveTaskLevelFailureRecovery() + { + super(RetryPolicy.TASK, ImmutableMap.of( + "exchange.base-directory", System.getProperty("java.io.tmpdir") + "/trino-local-file-system-exchange-manager", + "exchange.encryption-enabled", "true")); + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestPipelinedExecutionHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestPipelinedExecutionHiveConnectorTest.java new file mode 100644 index 000000000000..cbeb6443aab5 --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestPipelinedExecutionHiveConnectorTest.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive; + +import com.google.common.collect.ImmutableMap; + +public class TestPipelinedExecutionHiveConnectorTest + extends AbstractHiveConnectorTest +{ + public TestPipelinedExecutionHiveConnectorTest() + { + super(ImmutableMap.of(), ImmutableMap.of()); + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/containers/HiveMinioDataLake.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/containers/HiveMinioDataLake.java index d71cb7a1f04f..c1fbd60a7e93 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/containers/HiveMinioDataLake.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/containers/HiveMinioDataLake.java @@ -19,6 +19,7 @@ import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.collect.ImmutableMap; +import io.trino.testing.containers.Minio; import io.trino.util.AutoCloseableCloser; import org.testcontainers.containers.Network; diff --git a/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/ThriftQueryRunner.java b/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/ThriftQueryRunner.java index 19d58053970b..895e00e6b00a 100644 --- a/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/ThriftQueryRunner.java +++ b/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/integration/ThriftQueryRunner.java @@ -315,5 +315,11 @@ public void injectTaskFailure( { source.injectTaskFailure(traceToken, stageId, partitionId, attemptId, injectionType, errorType); } + + @Override + public void loadExchangeManager(String name, Map properties) + { + source.loadExchangeManager(name, properties); + } } } diff --git a/pom.xml b/pom.xml index bf8aabcb456b..5ceb6743f974 100644 --- a/pom.xml +++ b/pom.xml @@ -118,6 +118,7 @@ plugin/trino-druid plugin/trino-elasticsearch plugin/trino-example-http + plugin/trino-exchange plugin/trino-geospatial plugin/trino-google-sheets 
plugin/trino-hive @@ -243,6 +244,19 @@ ${project.version} + + io.trino + trino-exchange + ${project.version} + + + + io.trino + trino-exchange + test-jar + ${project.version} + + io.trino trino-geospatial diff --git a/testing/trino-server-dev/etc/config.properties b/testing/trino-server-dev/etc/config.properties index e1e4530f7421..f8f71a6ca7b7 100644 --- a/testing/trino-server-dev/etc/config.properties +++ b/testing/trino-server-dev/etc/config.properties @@ -27,6 +27,14 @@ scheduler.http-client.idle-timeout=1m query.client.timeout=5m query.min-expire-age=30m +retry-policy=TASK +enable-dynamic-filtering=false +distributed-sort=false + +query.initial-hash-partitions=5 +fault-tolerant-execution-target-task-input-size=10MB +fault-tolerant-execution-target-task-split-count=4 + plugin.bundles=\ ../../plugin/trino-resource-group-managers/pom.xml,\ ../../plugin/trino-password-authenticators/pom.xml, \ @@ -50,6 +58,7 @@ plugin.bundles=\ ../../plugin/trino-google-sheets/pom.xml, \ ../../plugin/trino-druid/pom.xml, \ ../../plugin/trino-geospatial/pom.xml, \ - ../../plugin/trino-http-event-listener/pom.xml + ../../plugin/trino-http-event-listener/pom.xml, \ + ../../plugin/trino-exchange/pom.xml node-scheduler.include-coordinator=true diff --git a/testing/trino-server-dev/etc/exchange-manager.properties b/testing/trino-server-dev/etc/exchange-manager.properties new file mode 100644 index 000000000000..31465c95d60b --- /dev/null +++ b/testing/trino-server-dev/etc/exchange-manager.properties @@ -0,0 +1,3 @@ +exchange-manager.name=filesystem +exchange.base-directory=s3://example-exchange-spooling-bucket +exchange.s3.region=us-west-1 diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestExchangeManager.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestExchangeManager.java new file mode 100644 index 000000000000..67ff6a415306 --- /dev/null +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestExchangeManager.java @@ -0-0
+1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.testing; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.Multimap; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.spi.QueryId; +import io.trino.spi.exchange.Exchange; +import io.trino.spi.exchange.ExchangeContext; +import io.trino.spi.exchange.ExchangeManager; +import io.trino.spi.exchange.ExchangeSink; +import io.trino.spi.exchange.ExchangeSinkHandle; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; +import io.trino.spi.exchange.ExchangeSource; +import io.trino.spi.exchange.ExchangeSourceHandle; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static org.assertj.core.api.Assertions.assertThat; +import static org.testng.Assert.assertTrue; + +public abstract class AbstractTestExchangeManager +{ + private ExchangeManager exchangeManager; + + @BeforeClass + public void init() + throws Exception + { + exchangeManager = createExchangeManager(); + } + + @AfterClass(alwaysRun = true) + public void destroy() + throws Exception + { + if (exchangeManager != null) { + 
exchangeManager = null; + } + } + + protected abstract ExchangeManager createExchangeManager(); + + @Test + public void testHappyPath() + throws Exception + { + Exchange exchange = exchangeManager.create(new ExchangeContext(new QueryId("query"), 0), 2); + ExchangeSinkHandle sinkHandle0 = exchange.addSink(0); + ExchangeSinkHandle sinkHandle1 = exchange.addSink(1); + ExchangeSinkHandle sinkHandle2 = exchange.addSink(2); + exchange.noMoreSinks(); + + ExchangeSinkInstanceHandle sinkHandle = exchange.instantiateSink(sinkHandle0, 0); + writeData( + sinkHandle, + ImmutableListMultimap.of( + 0, "0-0-0", + 1, "0-1-0", + 0, "0-0-1", + 1, "0-1-1"), + true); + exchange.sinkFinished(sinkHandle); + sinkHandle = exchange.instantiateSink(sinkHandle0, 1); + writeData( + sinkHandle, + ImmutableListMultimap.of( + 0, "0-0-0", + 1, "0-1-0", + 0, "0-0-1", + 1, "0-1-1"), + true); + exchange.sinkFinished(sinkHandle); + sinkHandle = exchange.instantiateSink(sinkHandle0, 2); + writeData( + sinkHandle, + ImmutableListMultimap.of( + 0, "failed", + 1, "another failed"), + false); + exchange.sinkFinished(sinkHandle); + + sinkHandle = exchange.instantiateSink(sinkHandle1, 0); + writeData( + sinkHandle, + ImmutableListMultimap.of( + 0, "1-0-0", + 1, "1-1-0", + 0, "1-0-1", + 1, "1-1-1"), + true); + exchange.sinkFinished(sinkHandle); + sinkHandle = exchange.instantiateSink(sinkHandle1, 1); + writeData( + sinkHandle, + ImmutableListMultimap.of( + 0, "1-0-0", + 1, "1-1-0", + 0, "1-0-1", + 1, "1-1-1"), + true); + exchange.sinkFinished(sinkHandle); + sinkHandle = exchange.instantiateSink(sinkHandle1, 2); + writeData( + sinkHandle, + ImmutableListMultimap.of( + 0, "more failed", + 1, "another failed"), + false); + exchange.sinkFinished(sinkHandle); + + sinkHandle = exchange.instantiateSink(sinkHandle2, 2); + writeData( + sinkHandle, + ImmutableListMultimap.of( + 0, "2-0-0", + 1, "2-1-0"), + true); + exchange.sinkFinished(sinkHandle); + + CompletableFuture> inputPartitionHandlesFuture = 
exchange.getSourceHandles(); + assertTrue(inputPartitionHandlesFuture.isDone()); + + List partitionHandles = inputPartitionHandlesFuture.get(); + assertThat(partitionHandles).hasSize(2); + + Map partitions = partitionHandles.stream() + .collect(toImmutableMap(ExchangeSourceHandle::getPartitionId, Function.identity())); + + assertThat(readData(partitions.get(0))) + .containsExactlyInAnyOrder("0-0-0", "0-0-1", "1-0-0", "1-0-1", "2-0-0"); + + assertThat(readData(partitions.get(1))) + .containsExactlyInAnyOrder("0-1-0", "0-1-1", "1-1-0", "1-1-1", "2-1-0"); + + exchange.close(); + } + + private void writeData(ExchangeSinkInstanceHandle handle, Multimap data, boolean finish) + { + ExchangeSink sink = exchangeManager.createSink(handle); + data.forEach((key, value) -> { + sink.add(key, Slices.utf8Slice(value)); + }); + if (finish) { + sink.finish(); + } + else { + sink.abort(); + } + } + + private List readData(ExchangeSourceHandle handle) + { + return readData(ImmutableList.of(handle)); + } + + private List readData(List handles) + { + ImmutableList.Builder result = ImmutableList.builder(); + try (ExchangeSource source = exchangeManager.createSource(handles)) { + while (!source.isFinished()) { + Slice data = source.read(); + if (data != null) { + result.add(data.toStringUtf8()); + } + } + } + return result.build(); + } +} diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFailureRecovery.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFailureRecovery.java index 7d5d24c35c61..21e339619856 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFailureRecovery.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFailureRecovery.java @@ -21,6 +21,7 @@ import io.trino.client.StageStats; import io.trino.client.StatementStats; import io.trino.execution.FailureInjector.InjectedFailureType; +import io.trino.operator.RetryPolicy; import io.trino.spi.ErrorType; import 
io.trino.tpch.TpchTable; import org.assertj.core.api.AbstractThrowableAssert; @@ -64,6 +65,15 @@ public abstract class AbstractTestFailureRecovery private static final Duration MAX_ERROR_DURATION = new Duration(10, SECONDS); private static final Duration REQUEST_TIMEOUT = new Duration(10, SECONDS); + private final RetryPolicy retryPolicy; + private final Map exchangeManagerProperties; + + protected AbstractTestFailureRecovery(RetryPolicy retryPolicy, Map exchangeManagerProperties) + { + this.retryPolicy = requireNonNull(retryPolicy, "retryPolicy is null"); + this.exchangeManagerProperties = requireNonNull(exchangeManagerProperties, "exchangeManagerProperties is null"); + } + @Override protected final QueryRunner createQueryRunner() throws Exception @@ -73,21 +83,27 @@ protected final QueryRunner createQueryRunner() ImmutableMap.builder() .put("query.remote-task.max-error-duration", MAX_ERROR_DURATION.toString()) .put("exchange.max-error-duration", MAX_ERROR_DURATION.toString()) - .put("retry-policy", "QUERY") + .put("retry-policy", retryPolicy.toString()) .put("retry-initial-delay", "0s") .put("retry-attempts", "1") .put("failure-injection.request-timeout", new Duration(REQUEST_TIMEOUT.toMillis() * 2, MILLISECONDS).toString()) .put("exchange.http-client.idle-timeout", REQUEST_TIMEOUT.toString()) + .put("query.initial-hash-partitions", "5") // TODO: re-enable once failure recover supported for this functionality .put("enable-dynamic-filtering", "false") .put("distributed-sort", "false") .build(), ImmutableMap.builder() .put("scheduler.http-client.idle-timeout", REQUEST_TIMEOUT.toString()) - .build()); + .build(), + exchangeManagerProperties); } - protected abstract QueryRunner createQueryRunner(List> requiredTpchTables, Map configProperties, Map coordinatorProperties) + protected abstract QueryRunner createQueryRunner( + List> requiredTpchTables, + Map configProperties, + Map coordinatorProperties, + Map exchangeManagerProperties) throws Exception; 
@Test(invocationCount = INVOCATION_COUNT) diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFaultTolerantExecutionAggregations.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFaultTolerantExecutionAggregations.java new file mode 100644 index 000000000000..646bf0c127fc --- /dev/null +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFaultTolerantExecutionAggregations.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.testing; + +import java.util.Map; + +public abstract class AbstractTestFaultTolerantExecutionAggregations + extends AbstractTestAggregations +{ + @Override + protected final QueryRunner createQueryRunner() + throws Exception + { + return createQueryRunner(BaseFaultTolerantExecutionConnectorTest.getExtraProperties()); + } + + protected abstract QueryRunner createQueryRunner(Map<String, String> extraProperties) + throws Exception; +} diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFaultTolerantExecutionJoinQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFaultTolerantExecutionJoinQueries.java new file mode 100644 index 000000000000..265d093c5ef0 --- /dev/null +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFaultTolerantExecutionJoinQueries.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.testing; + +import java.util.Map; + +public abstract class AbstractTestFaultTolerantExecutionJoinQueries + extends AbstractTestJoinQueries +{ + @Override + protected final QueryRunner createQueryRunner() + throws Exception + { + return createQueryRunner(BaseFaultTolerantExecutionConnectorTest.getExtraProperties()); + } + + protected abstract QueryRunner createQueryRunner(Map<String, String> extraProperties) + throws Exception; +} diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFaultTolerantExecutionOrderByQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFaultTolerantExecutionOrderByQueries.java new file mode 100644 index 000000000000..82f847d4aa24 --- /dev/null +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFaultTolerantExecutionOrderByQueries.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.testing; + +import java.util.Map; + +public abstract class AbstractTestFaultTolerantExecutionOrderByQueries + extends AbstractTestOrderByQueries +{ + @Override + protected final QueryRunner createQueryRunner() + throws Exception + { + return createQueryRunner(BaseFaultTolerantExecutionConnectorTest.getExtraProperties()); + } + + protected abstract QueryRunner createQueryRunner(Map<String, String> extraProperties) + throws Exception; +} diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFaultTolerantExecutionWindowQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFaultTolerantExecutionWindowQueries.java new file mode 100644 index 000000000000..90aa5d5d5d84 --- /dev/null +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestFaultTolerantExecutionWindowQueries.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.testing; + +import java.util.Map; + +public abstract class AbstractTestFaultTolerantExecutionWindowQueries + extends AbstractTestWindowQueries +{ + @Override + protected final QueryRunner createQueryRunner() + throws Exception + { + return createQueryRunner(BaseFaultTolerantExecutionConnectorTest.getExtraProperties()); + } + + protected abstract QueryRunner createQueryRunner(Map<String, String> extraProperties) + throws Exception; +} diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseFaultTolerantExecutionConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseFaultTolerantExecutionConnectorTest.java new file mode 100644 index 000000000000..959daacc3c88 --- /dev/null +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseFaultTolerantExecutionConnectorTest.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.testing; + +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +public abstract class BaseFaultTolerantExecutionConnectorTest + extends BaseConnectorTest +{ + @Override + protected final QueryRunner createQueryRunner() + throws Exception + { + return createQueryRunner(BaseFaultTolerantExecutionConnectorTest.getExtraProperties()); + } + + protected abstract QueryRunner createQueryRunner(Map<String, String> extraProperties) + throws Exception; + + public static Map<String, String> getExtraProperties() + { + return ImmutableMap.<String, String>builder() + .put("retry-policy", "TASK") + .put("query.initial-hash-partitions", "5") + .put("fault-tolerant-execution-target-task-input-size", "10MB") + .put("fault-tolerant-execution-target-task-split-count", "4") + // TODO: re-enable once failure recovery is supported for this functionality + .put("enable-dynamic-filtering", "false") + .put("distributed-sort", "false") + .build(); + } + + public static Map<String, String> getExchangeManagerPropertiesFile() + { + return ImmutableMap.<String, String>builder() + .put("exchange.base-directory", System.getProperty("java.io.tmpdir") + "/trino-local-file-system-exchange-manager") + .put("exchange.encryption-enabled", "true") + .build(); + } +} diff --git a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java index 5949b2458563..5d2a184564f0 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java @@ -557,6 +557,14 @@ } + + @Override + public void loadExchangeManager(String name, Map<String, String> properties) + { + for (TestingTrinoServer server : servers) { + server.loadExchangeManager(name, properties); + } + } + @Override + public final void close() { diff --git a/testing/trino-testing/src/main/java/io/trino/testing/StandaloneQueryRunner.java
b/testing/trino-testing/src/main/java/io/trino/testing/StandaloneQueryRunner.java index 95f1573830de..9435cc45475c 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/StandaloneQueryRunner.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/StandaloneQueryRunner.java @@ -280,6 +280,12 @@ public void injectTaskFailure( errorType); } + @Override + public void loadExchangeManager(String name, Map properties) + { + server.loadExchangeManager(name, properties); + } + private static TestingTrinoServer createTestingTrinoServer() { return TestingTrinoServer.builder() diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/containers/Minio.java b/testing/trino-testing/src/main/java/io/trino/testing/containers/Minio.java similarity index 97% rename from plugin/trino-hive/src/test/java/io/trino/plugin/hive/containers/Minio.java rename to testing/trino-testing/src/main/java/io/trino/testing/containers/Minio.java index fb6e0d89a8ce..a66a68169c33 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/containers/Minio.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/containers/Minio.java @@ -11,13 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.plugin.hive.containers; +package io.trino.testing.containers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.net.HostAndPort; import io.airlift.log.Logger; -import io.trino.testing.containers.BaseTestContainer; import org.testcontainers.containers.Network; import java.util.Map;