diff --git a/core/trino-main/src/main/java/io/trino/exchange/LazyExchangeDataSource.java b/core/trino-main/src/main/java/io/trino/exchange/LazyExchangeDataSource.java index 6325ba91afb9..58e697d7bf71 100644 --- a/core/trino-main/src/main/java/io/trino/exchange/LazyExchangeDataSource.java +++ b/core/trino-main/src/main/java/io/trino/exchange/LazyExchangeDataSource.java @@ -116,6 +116,7 @@ public void addInput(ExchangeInput input) return; } ExchangeDataSource dataSource = delegate.get(); + boolean inputAdded = false; if (dataSource == null) { if (input instanceof DirectExchangeInput) { DirectExchangeClient client = directExchangeClientSupplier.get(queryId, exchangeId, systemMemoryContext, taskFailureListener, retryPolicy); @@ -126,7 +127,8 @@ else if (input instanceof SpoolingExchangeInput) { ExchangeManager exchangeManager = exchangeManagerRegistry.getExchangeManager(); List sourceHandles = spoolingExchangeInput.getExchangeSourceHandles(); ExchangeSource exchangeSource = exchangeManager.createSource(sourceHandles); - dataSource = new SpoolingExchangeDataSource(exchangeSource, sourceHandles, systemMemoryContext); + dataSource = new SpoolingExchangeDataSource(exchangeSource, systemMemoryContext); + inputAdded = true; } else { throw new IllegalArgumentException("Unexpected input: " + input); @@ -134,7 +136,9 @@ else if (input instanceof SpoolingExchangeInput) { delegate.set(dataSource); initialized = true; } - dataSource.addInput(input); + if (!inputAdded) { + dataSource.addInput(input); + } } if (initialized) { diff --git a/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeDataSource.java b/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeDataSource.java index 51b5cc624444..62a44ea83c4f 100644 --- a/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeDataSource.java +++ b/core/trino-main/src/main/java/io/trino/exchange/SpoolingExchangeDataSource.java @@ -13,18 +13,13 @@ */ package io.trino.exchange; -import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.trino.memory.context.LocalMemoryContext; import io.trino.operator.OperatorInfo; import io.trino.spi.exchange.ExchangeSource; -import io.trino.spi.exchange.ExchangeSourceHandle; -import java.util.List; - -import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static io.airlift.concurrent.MoreFutures.toListenableFuture; import static java.util.Objects.requireNonNull; @@ -39,18 +34,13 @@ public class SpoolingExchangeDataSource // It doesn't have to be declared as volatile as the nullification of this variable doesn't have to be immediately visible to other threads. // However since close can be called at any moment this variable has to be accessed in a safe way (avoiding "check-then-use"). private ExchangeSource exchangeSource; - private final List exchangeSourceHandles; private final LocalMemoryContext systemMemoryContext; private volatile boolean closed; - public SpoolingExchangeDataSource( - ExchangeSource exchangeSource, - List exchangeSourceHandles, - LocalMemoryContext systemMemoryContext) + public SpoolingExchangeDataSource(ExchangeSource exchangeSource, LocalMemoryContext systemMemoryContext) { // this assignment is expected to be followed by an assignment of a final field to ensure safe publication this.exchangeSource = requireNonNull(exchangeSource, "exchangeSource is null"); - this.exchangeSourceHandles = ImmutableList.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null")); this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); } @@ -96,16 +86,7 @@ public ListenableFuture isBlocked() @Override public void addInput(ExchangeInput input) { - SpoolingExchangeInput exchangeInput = (SpoolingExchangeInput) input; - // Only a single input is expected when the spooling exchange is used. - // The engine adds the same input to every instance of the ExchangeOperator. - // Since the ExchangeDataSource is shared between ExchangeOperator instances - // the same input may be delivered multiple times. - checkState( - exchangeInput.getExchangeSourceHandles().equals(exchangeSourceHandles), - "exchange input is expected to contain an identical exchangeSourceHandles list: %s != %s", - exchangeInput.getExchangeSourceHandles(), - exchangeSourceHandles); + throw new UnsupportedOperationException("only a single input is expected"); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/DataDefinitionExecution.java b/core/trino-main/src/main/java/io/trino/execution/DataDefinitionExecution.java index 9aacb600f157..521926d8dd6b 100644 --- a/core/trino-main/src/main/java/io/trino/execution/DataDefinitionExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/DataDefinitionExecution.java @@ -165,7 +165,7 @@ public void onFailure(Throwable throwable) } @Override - public void addOutputInfoListener(Consumer listener) + public void setOutputInfoListener(Consumer listener) { // DDL does not have an output } diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/QueryExecution.java index f0608f70b7ab..38c5a85358fa 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryExecution.java @@ -28,6 +28,7 @@ import io.trino.sql.planner.Plan; import java.util.List; +import java.util.Queue; import java.util.function.Consumer; import static java.util.Objects.requireNonNull; @@ -41,7 +42,7 @@ public interface QueryExecution void addStateChangeListener(StateChangeListener stateChangeListener); - void addOutputInfoListener(Consumer listener); + void setOutputInfoListener(Consumer listener); void outputTaskFailed(TaskId taskId, Throwable failure); @@ -86,23 +87,23 @@ interface QueryExecutionFactory } /** - * Output schema and buffer URIs for query. The info will always contain column names and types. Buffer locations will always - * contain the full location set, but may be empty. Users of this data should keep a private copy of the seen buffers to - * handle out of order events from the listener. Once noMoreBufferLocations is set the locations will never change, and - * it is guaranteed that all previously sent locations are contained in the buffer locations. + * The info will always contain column names and types. + * The {@code inputsQueue} is shared between {@link QueryOutputInfo} instances. + * It is guaranteed that no new entries will be added to {@code inputsQueue} after {@link QueryOutputInfo} + * with {@link #isNoMoreInputs()} {@code == true} is created. */ class QueryOutputInfo { private final List columnNames; private final List columnTypes; - private final List inputs; + private final Queue inputsQueue; private final boolean noMoreInputs; - public QueryOutputInfo(List columnNames, List columnTypes, List inputs, boolean noMoreInputs) + public QueryOutputInfo(List columnNames, List columnTypes, Queue inputsQueue, boolean noMoreInputs) { this.columnNames = ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null")); this.columnTypes = ImmutableList.copyOf(requireNonNull(columnTypes, "columnTypes is null")); - this.inputs = ImmutableList.copyOf(requireNonNull(inputs, "inputs is null")); + this.inputsQueue = requireNonNull(inputsQueue, "inputsQueue is null"); this.noMoreInputs = noMoreInputs; } @@ -116,9 +117,15 @@ public List getColumnTypes() return columnTypes; } - public List getInputs() + public void drainInputs(Consumer consumer) { - return inputs; + while (true) { + ExchangeInput input = inputsQueue.poll(); + if (input == null) { + break; + } + consumer.accept(input); + } } public boolean isNoMoreInputs() diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryManager.java b/core/trino-main/src/main/java/io/trino/execution/QueryManager.java index 1543d7c5dc39..c0d3c588e389 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryManager.java @@ -33,7 +33,7 @@ public interface QueryManager * * @throws NoSuchElementException if query does not exist */ - void addOutputInfoListener(QueryId queryId, Consumer listener) + void setOutputInfoListener(QueryId queryId, Consumer listener) throws NoSuchElementException; /** diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java index fdf0e81164c8..7eb106408385 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java @@ -64,8 +64,10 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Queue; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -284,6 +286,7 @@ static QueryStateMachine beginWithTicker( QUERY_STATE_LOG.debug("Query %s is %s", queryStateMachine.getQueryId(), newState); if (newState.isDone()) { queryStateMachine.getSession().getTransactionId().ifPresent(transactionManager::trySetInactive); + queryStateMachine.getOutputManager().setQueryCompleted(); } }); @@ -711,9 +714,9 @@ private QueryStats getQueryStats(Optional rootStage, List operatorStatsSummary.build()); } - public void addOutputInfoListener(Consumer listener) + public void setOutputInfoListener(Consumer listener) { - outputManager.addOutputInfoListener(listener); + outputManager.setOutputInfoListener(listener); } public void addOutputTaskFailureListener(TaskFailureListener listener) @@ -1282,21 +1285,28 @@ private static QueryStats pruneQueryStats(QueryStats queryStats) ImmutableList.of()); // Remove the operator summaries as OperatorInfo (especially DirectExchangeClientStatus) can hold onto a large amount of memory } + private QueryOutputManager getOutputManager() + { + return outputManager; + } + public static class QueryOutputManager { private final Executor executor; @GuardedBy("this") - private final List> outputInfoListeners = new ArrayList<>(); + private Optional> listener = Optional.empty(); @GuardedBy("this") private List columnNames; @GuardedBy("this") private List columnTypes; @GuardedBy("this") - private final List inputs = new ArrayList<>(); - @GuardedBy("this") private boolean noMoreInputs; + @GuardedBy("this") + private boolean queryCompleted; + + private final Queue inputsQueue = new ConcurrentLinkedQueue<>(); @GuardedBy("this") private final Map outputTaskFailures = new HashMap<>(); @@ -1308,16 +1318,17 @@ public QueryOutputManager(Executor executor) this.executor = requireNonNull(executor, "executor is null"); } - public void addOutputInfoListener(Consumer listener) + public void setOutputInfoListener(Consumer listener) { requireNonNull(listener, "listener is null"); Optional queryOutputInfo; synchronized (this) { - outputInfoListeners.add(listener); + checkState(this.listener.isEmpty(), "listener is already set"); + this.listener = Optional.of(listener); queryOutputInfo = getQueryOutputInfo(); } - queryOutputInfo.ifPresent(info -> executor.execute(() -> listener.accept(info))); + fireStateChangedIfReady(queryOutputInfo, Optional.of(listener)); } public void setColumns(List columnNames, List columnTypes) @@ -1327,16 +1338,16 @@ public void setColumns(List columnNames, List columnTypes) checkArgument(columnNames.size() == columnTypes.size(), "columnNames and columnTypes must be the same size"); Optional queryOutputInfo; - List> outputInfoListeners; + Optional> listener; synchronized (this) { checkState(this.columnNames == null && this.columnTypes == null, "output fields already set"); this.columnNames = ImmutableList.copyOf(columnNames); this.columnTypes = ImmutableList.copyOf(columnTypes); queryOutputInfo = getQueryOutputInfo(); - outputInfoListeners = ImmutableList.copyOf(this.outputInfoListeners); + listener = this.listener; } - queryOutputInfo.ifPresent(info -> fireStateChanged(info, outputInfoListeners)); + fireStateChangedIfReady(queryOutputInfo, listener); } public void updateInputsForQueryResults(List newInputs, boolean noMoreInputs) @@ -1344,16 +1355,28 @@ public void updateInputsForQueryResults(List newInputs, boolean n requireNonNull(newInputs, "newInputs is null"); Optional queryOutputInfo; - List> outputInfoListeners; + Optional> listener; synchronized (this) { - // noMoreInputs can be set more than once - checkState(newInputs.isEmpty() || !this.noMoreInputs, "new inputs added after no more inputs set"); - inputs.addAll(newInputs); - this.noMoreInputs = noMoreInputs; + if (!queryCompleted) { + // noMoreInputs can be set more than once + checkState(newInputs.isEmpty() || !this.noMoreInputs, "new inputs added after no more inputs set"); + inputsQueue.addAll(newInputs); + this.noMoreInputs = noMoreInputs; + } queryOutputInfo = getQueryOutputInfo(); - outputInfoListeners = ImmutableList.copyOf(this.outputInfoListeners); + listener = this.listener; + } + fireStateChangedIfReady(queryOutputInfo, listener); + } + + public synchronized void setQueryCompleted() + { + if (queryCompleted) { + return; } - queryOutputInfo.ifPresent(info -> fireStateChanged(info, outputInfoListeners)); + queryCompleted = true; + inputsQueue.clear(); + noMoreInputs = true; } public void addOutputTaskFailureListener(TaskFailureListener listener) @@ -1387,14 +1410,15 @@ private synchronized Optional getQueryOutputInfo() if (columnNames == null || columnTypes == null) { return Optional.empty(); } - return Optional.of(new QueryOutputInfo(columnNames, columnTypes, inputs, noMoreInputs)); + return Optional.of(new QueryOutputInfo(columnNames, columnTypes, inputsQueue, noMoreInputs)); } - private void fireStateChanged(QueryOutputInfo queryOutputInfo, List> outputInfoListeners) + private void fireStateChangedIfReady(Optional info, Optional> listener) { - for (Consumer outputInfoListener : outputInfoListeners) { - executor.execute(() -> outputInfoListener.accept(queryOutputInfo)); + if (info.isEmpty() || listener.isEmpty()) { + return; } + executor.execute(() -> listener.get().accept(info.get())); } } } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index 1f60d3b5bc82..55c88d341350 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -598,9 +598,9 @@ public boolean isDone() } @Override - public void addOutputInfoListener(Consumer listener) + public void setOutputInfoListener(Consumer listener) { - stateMachine.addOutputInfoListener(listener); + stateMachine.setOutputInfoListener(listener); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java index f04472c2f06c..248cd6a6e7b2 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java @@ -148,11 +148,11 @@ public List getQueries() } @Override - public void addOutputInfoListener(QueryId queryId, Consumer listener) + public void setOutputInfoListener(QueryId queryId, Consumer listener) { requireNonNull(listener, "listener is null"); - queryTracker.getQuery(queryId).addOutputInfoListener(listener); + queryTracker.getQuery(queryId).setOutputInfoListener(listener); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java index 0ed02836060e..a104d0a6b527 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java @@ -44,16 +44,14 @@ import java.lang.ref.WeakReference; import java.util.ArrayList; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.OptionalInt; +import java.util.Queue; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; @@ -94,14 +92,9 @@ public class SqlTaskExecution private final SplitMonitor splitMonitor; - private final List> drivers = new CopyOnWriteArrayList<>(); - private final Map driverRunnerFactoriesWithSplitLifeCycle; private final List driverRunnerFactoriesWithTaskLifeCycle; - - // guarded for update only - @GuardedBy("this") - private final ConcurrentMap unpartitionedSplitAssignments = new ConcurrentHashMap<>(); + private final Map driverRunnerFactoriesWithRemoteSource; @GuardedBy("this") private long maxAcknowledgedSplit = Long.MIN_VALUE; @@ -141,17 +134,21 @@ public SqlTaskExecution( Set partitionedSources = ImmutableSet.copyOf(localExecutionPlan.getPartitionedSourceOrder()); ImmutableMap.Builder driverRunnerFactoriesWithSplitLifeCycle = ImmutableMap.builder(); ImmutableList.Builder driverRunnerFactoriesWithTaskLifeCycle = ImmutableList.builder(); + ImmutableMap.Builder driverRunnerFactoriesWithRemoteSource = ImmutableMap.builder(); for (DriverFactory driverFactory : localExecutionPlan.getDriverFactories()) { Optional sourceId = driverFactory.getSourceId(); if (sourceId.isPresent() && partitionedSources.contains(sourceId.get())) { driverRunnerFactoriesWithSplitLifeCycle.put(sourceId.get(), new DriverSplitRunnerFactory(driverFactory, true)); } else { - driverRunnerFactoriesWithTaskLifeCycle.add(new DriverSplitRunnerFactory(driverFactory, false)); + DriverSplitRunnerFactory runnerFactory = new DriverSplitRunnerFactory(driverFactory, false); + sourceId.ifPresent(planNodeId -> driverRunnerFactoriesWithRemoteSource.put(planNodeId, runnerFactory)); + driverRunnerFactoriesWithTaskLifeCycle.add(runnerFactory); } } this.driverRunnerFactoriesWithSplitLifeCycle = driverRunnerFactoriesWithSplitLifeCycle.buildOrThrow(); this.driverRunnerFactoriesWithTaskLifeCycle = driverRunnerFactoriesWithTaskLifeCycle.build(); + this.driverRunnerFactoriesWithRemoteSource = driverRunnerFactoriesWithRemoteSource.buildOrThrow(); this.pendingSplitsByPlanNode = this.driverRunnerFactoriesWithSplitLifeCycle.keySet().stream() .collect(toImmutableMap(identity(), ignore -> new PendingSplitsForPlanNode())); @@ -223,39 +220,20 @@ public void addSplitAssignments(List splitAssignments) try (SetThreadName ignored = new SetThreadName("Task-%s", taskId)) { // update our record of split assignments and schedule drivers for new partitioned splits - Map updatedUnpartitionedSources = updateSplitAssignments(splitAssignments); - - // tell existing drivers about the new splits; it is safe to update drivers - // multiple times and out of order because split assignments contain full record of - // the unpartitioned splits - for (WeakReference driverReference : drivers) { - Driver driver = driverReference.get(); - // the driver can be GCed due to a failure or a limit - if (driver == null) { - // remove the weak reference from the list to avoid a memory leak - // NOTE: this is a concurrent safe operation on a CopyOnWriteArrayList - drivers.remove(driverReference); - continue; - } - Optional sourceId = driver.getSourceId(); - if (sourceId.isEmpty()) { - continue; - } - SplitAssignment splitAssignmentUpdate = updatedUnpartitionedSources.get(sourceId.get()); - if (splitAssignmentUpdate == null) { - continue; - } - driver.updateSplitAssignment(splitAssignmentUpdate); + Set updatedUnpartitionedSources = updateSplitAssignments(splitAssignments); + for (PlanNodeId planNodeId : updatedUnpartitionedSources) { + DriverSplitRunnerFactory factory = driverRunnerFactoriesWithRemoteSource.get(planNodeId); + // schedule splits outside the lock + factory.scheduleSplits(); } - // we may have transitioned to no more splits, so check for completion checkTaskCompletion(); } } - private synchronized Map updateSplitAssignments(List splitAssignments) + private synchronized Set updateSplitAssignments(List splitAssignments) { - Map updatedUnpartitionedSplitAssignments = new HashMap<>(); + ImmutableSet.Builder updatedUnpartitionedSources = ImmutableSet.builder(); // first remove any split that was already acknowledged long currentMaxAcknowledgedSplit = this.maxAcknowledgedSplit; @@ -274,7 +252,10 @@ private synchronized Map updateSplitAssignments(Lis schedulePartitionedSource(assignment); } else { - scheduleUnpartitionedSource(assignment, updatedUnpartitionedSplitAssignments); + // tell existing drivers about the new splits + DriverSplitRunnerFactory factory = driverRunnerFactoriesWithRemoteSource.get(assignment.getPlanNodeId()); + factory.enqueueSplits(assignment.getSplits(), assignment.isNoMoreSplits()); + updatedUnpartitionedSources.add(assignment.getPlanNodeId()); } } @@ -284,7 +265,7 @@ private synchronized Map updateSplitAssignments(Lis .mapToLong(ScheduledSplit::getSequenceId) .max() .orElse(maxAcknowledgedSplit); - return updatedUnpartitionedSplitAssignments; + return updatedUnpartitionedSources.build(); } @GuardedBy("this") @@ -335,25 +316,6 @@ private synchronized void schedulePartitionedSource(SplitAssignment splitAssignm } } - private synchronized void scheduleUnpartitionedSource(SplitAssignment splitAssignmentUpdate, Map updatedUnpartitionedSources) - { - // create new source - SplitAssignment newSplitAssignment; - SplitAssignment currentSplitAssignment = unpartitionedSplitAssignments.get(splitAssignmentUpdate.getPlanNodeId()); - if (currentSplitAssignment == null) { - newSplitAssignment = splitAssignmentUpdate; - } - else { - newSplitAssignment = currentSplitAssignment.update(splitAssignmentUpdate); - } - - // only record new source if something changed - if (newSplitAssignment != currentSplitAssignment) { - unpartitionedSplitAssignments.put(splitAssignmentUpdate.getPlanNodeId(), newSplitAssignment); - updatedUnpartitionedSources.put(splitAssignmentUpdate.getPlanNodeId(), newSplitAssignment); - } - } - private void scheduleDriversForTaskLifeCycle() { // This method is called at the beginning of the task. @@ -436,14 +398,14 @@ private DriverStats getDriverStats() public synchronized Set getNoMoreSplits() { ImmutableSet.Builder noMoreSplits = ImmutableSet.builder(); - for (Entry entry : driverRunnerFactoriesWithSplitLifeCycle.entrySet()) { + for (Map.Entry entry : driverRunnerFactoriesWithSplitLifeCycle.entrySet()) { if (entry.getValue().isNoMoreDriverRunner()) { noMoreSplits.add(entry.getKey()); } } - for (SplitAssignment splitAssignment : unpartitionedSplitAssignments.values()) { - if (splitAssignment.isNoMoreSplits()) { - noMoreSplits.add(splitAssignment.getPlanNodeId()); + for (Map.Entry entry : driverRunnerFactoriesWithRemoteSource.entrySet()) { + if (entry.getValue().isNoMoreSplits()) { + noMoreSplits.add(entry.getKey()); } } return noMoreSplits.build(); @@ -499,7 +461,6 @@ public String toString() return toStringHelper(this) .add("taskId", taskId) .add("remainingDrivers", remainingDrivers.get()) - .add("unpartitionedSplitAssignments", unpartitionedSplitAssignments) .toString(); } @@ -579,6 +540,11 @@ private class DriverSplitRunnerFactory // true if no more DriverSplitRunners will be created private final AtomicBoolean noMoreDriverRunner = new AtomicBoolean(); + private final List> driverReferences = new CopyOnWriteArrayList<>(); + private final Queue queuedSplits = new ConcurrentLinkedQueue<>(); + private final AtomicLong inFlightSplits = new AtomicLong(); + private final AtomicBoolean noMoreSplits = new AtomicBoolean(); + private DriverSplitRunnerFactory(DriverFactory driverFactory, boolean partitioned) { this.driverFactory = driverFactory; @@ -602,30 +568,77 @@ public Driver createDriver(DriverContext driverContext, @Nullable ScheduledSplit { Driver driver = driverFactory.createDriver(driverContext); - // record driver so other threads add unpartitioned sources can see the driver - // NOTE: this MUST be done before reading unpartitionedSources, so we see a consistent view of the unpartitioned sources - drivers.add(new WeakReference<>(driver)); - if (partitionedSplit != null) { // TableScanOperator requires partitioned split to be added before the first call to process driver.updateSplitAssignment(new SplitAssignment(partitionedSplit.getPlanNodeId(), ImmutableSet.of(partitionedSplit), true)); } - // add unpartitioned sources - Optional sourceId = driver.getSourceId(); - if (sourceId.isPresent()) { - SplitAssignment splitAssignment = unpartitionedSplitAssignments.get(sourceId.get()); - if (splitAssignment != null) { - driver.updateSplitAssignment(splitAssignment); - } - } - pendingCreations.decrementAndGet(); closeDriverFactoryIfFullyCreated(); + if (driverFactory.getSourceId().isPresent() && partitionedSplit == null) { + driverReferences.add(new WeakReference<>(driver)); + scheduleSplits(); + } + return driver; } + public void enqueueSplits(Set splits, boolean noMoreSplits) + { + verify(driverFactory.getSourceId().isPresent(), "not a source driver"); + verify(!this.noMoreSplits.get() || splits.isEmpty(), "cannot add splits after noMoreSplits is set"); + queuedSplits.addAll(splits); + verify(!this.noMoreSplits.get() || noMoreSplits, "cannot unset noMoreSplits"); + if (noMoreSplits) { + this.noMoreSplits.set(true); + } + } + + public void scheduleSplits() + { + if (driverReferences.isEmpty()) { + return; + } + + PlanNodeId sourceId = driverFactory.getSourceId().orElseThrow(); + while (!queuedSplits.isEmpty()) { + int activeDriversCount = 0; + for (WeakReference driverReference : driverReferences) { + Driver driver = driverReference.get(); + if (driver == null) { + continue; + } + activeDriversCount++; + inFlightSplits.incrementAndGet(); + ScheduledSplit split = queuedSplits.poll(); + if (split == null) { + inFlightSplits.decrementAndGet(); + break; + } + driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(split), false)); + inFlightSplits.decrementAndGet(); + } + if (activeDriversCount == 0) { + break; + } + } + + if (noMoreSplits.get() && queuedSplits.isEmpty() && inFlightSplits.get() == 0) { + for (WeakReference driverReference : driverReferences) { + Driver driver = driverReference.get(); + if (driver != null) { + driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(), true)); + } + } + } + } + + public boolean isNoMoreSplits() + { + return noMoreSplits.get(); + } + public void noMoreDriverRunner() { noMoreDriverRunner.set(true); diff --git a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java index 32415aa1d3f2..f5e461cc3e5d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java @@ -45,6 +45,7 @@ import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.Sets.newConcurrentHashSet; import static java.util.Objects.requireNonNull; @@ -155,10 +156,7 @@ public synchronized void addLocation(TaskId taskId, URI location) return; } - // ignore duplicate locations - if (allClients.containsKey(location)) { - return; - } + checkArgument(!allClients.containsKey(location), "location already exist: %s", location); checkState(!noMoreLocations, "No more locations already set"); buffer.addTask(taskId); diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java b/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java index 3d1cb6ef8911..d84b3b0c4c48 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java @@ -29,6 +29,10 @@ import io.trino.spi.exchange.ExchangeId; import io.trino.split.RemoteSplit; import io.trino.sql.planner.plan.PlanNodeId; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; + +import javax.annotation.concurrent.ThreadSafe; import java.util.Optional; import java.util.function.Supplier; @@ -56,6 +60,9 @@ public static class ExchangeOperatorFactory private ExchangeDataSource exchangeDataSource; private boolean closed; + private final NoMoreSplitsTracker noMoreSplitsTracker = new NoMoreSplitsTracker(); + private int nextOperatorInstanceId; + public ExchangeOperatorFactory( int operatorId, PlanNodeId sourceId, @@ -99,16 +106,28 @@ public SourceOperator createOperator(DriverContext driverContext) retryPolicy, exchangeManagerRegistry); } - return new ExchangeOperator( + int operatorInstanceId = nextOperatorInstanceId; + nextOperatorInstanceId++; + ExchangeOperator exchangeOperator = new ExchangeOperator( operatorContext, sourceId, exchangeDataSource, - serdeFactory.createPagesSerde()); + serdeFactory.createPagesSerde(), + noMoreSplitsTracker, + operatorInstanceId); + noMoreSplitsTracker.operatorAdded(operatorInstanceId); + return exchangeOperator; } @Override public void noMoreOperators() { + noMoreSplitsTracker.noMoreOperators(); + if (noMoreSplitsTracker.isNoMoreSplits()) { + if (exchangeDataSource != null) { + exchangeDataSource.noMoreInputs(); + } + } closed = true; } } @@ -117,18 +136,25 @@ public void noMoreOperators() private final PlanNodeId sourceId; private final ExchangeDataSource exchangeDataSource; private final PagesSerde serde; + private final NoMoreSplitsTracker noMoreSplitsTracker; + private final int operatorInstanceId; + private ListenableFuture isBlocked = NOT_BLOCKED; public ExchangeOperator( OperatorContext operatorContext, PlanNodeId sourceId, ExchangeDataSource exchangeDataSource, - PagesSerde serde) + PagesSerde serde, + NoMoreSplitsTracker noMoreSplitsTracker, + int operatorInstanceId) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.sourceId = requireNonNull(sourceId, "sourceId is null"); this.exchangeDataSource = requireNonNull(exchangeDataSource, "exchangeDataSource is null"); this.serde = requireNonNull(serde, "serde is null"); + this.noMoreSplitsTracker = requireNonNull(noMoreSplitsTracker, "noMoreSplitsTracker is null"); + this.operatorInstanceId = operatorInstanceId; operatorContext.setInfoSupplier(exchangeDataSource::getInfo); } @@ -154,7 +180,10 @@ public Supplier> addSplit(Split split) @Override public void noMoreSplits() { - exchangeDataSource.noMoreInputs(); + noMoreSplitsTracker.noMoreSplits(operatorInstanceId); + if (noMoreSplitsTracker.isNoMoreSplits()) { + exchangeDataSource.noMoreInputs(); + } } @Override @@ -220,4 +249,34 @@ public void close() { exchangeDataSource.close(); } + + @ThreadSafe + private static class NoMoreSplitsTracker + { + private final IntSet allOperators = new IntOpenHashSet(); + private final IntSet noMoreSplitsOperators = new IntOpenHashSet(); + private boolean noMoreOperators; + + public synchronized void operatorAdded(int operatorInstanceId) + { + checkState(!noMoreOperators, "noMoreOperators is set"); + allOperators.add(operatorInstanceId); + } + + public synchronized void noMoreOperators() + { + noMoreOperators = true; + } + + public synchronized void noMoreSplits(int operatorInstanceId) + { + checkState(allOperators.contains(operatorInstanceId), "operatorInstanceId not found: %s", operatorInstanceId); + noMoreSplitsOperators.add(operatorInstanceId); + } + + public synchronized boolean isNoMoreSplits() + { + return noMoreOperators && noMoreSplitsOperators.containsAll(allOperators); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java index 6406a12e047c..c107956ae9ca 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java @@ -187,7 +187,7 @@ public static Query create( Query result = new Query(session, slug, queryManager, queryInfoUrl, exchangeDataSource, dataProcessorExecutor, timeoutExecutor, blockEncodingSerde); - result.queryManager.addOutputInfoListener(result.getQueryId(), result::setQueryOutputInfo); + result.queryManager.setOutputInfoListener(result.getQueryId(), result::setQueryOutputInfo); result.queryManager.addStateChangeListener(result.getQueryId(), state -> { // Wait for the query info to become available and close the exchange client if there is no output stage for the query results to be pulled from. @@ -582,7 +582,7 @@ private synchronized void setQueryOutputInfo(QueryExecution.QueryOutputInfo outp types = outputInfo.getColumnTypes(); } - outputInfo.getInputs().forEach(exchangeDataSource::addInput); + outputInfo.drainInputs(exchangeDataSource::addInput); if (outputInfo.isNoMoreInputs()) { exchangeDataSource.noMoreInputs(); } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java index f2b820946cf7..afd4d1495774 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java @@ -272,6 +272,7 @@ private SourceOperator createExchangeOperator() SourceOperator operator = operatorFactory.createOperator(driverContext); assertEquals(getOnlyElement(operator.getOperatorContext().getNestedOperatorStats()).getUserMemoryReservation().toBytes(), 0); + operatorFactory.noMoreOperators(); return operator; } diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceHandle.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceHandle.java index 20130c578eb2..0e5cfbb7518c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceHandle.java +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSourceHandle.java @@ -16,7 +16,7 @@ import io.trino.spi.Experimental; /* - * Implementation is expected to be Jackson serializable and include equals, hashCode and toString methods + * Implementation is expected to be Jackson serializable */ @Experimental(eta = "2023-01-01") public interface ExchangeSourceHandle @@ -26,13 +26,4 @@ public interface ExchangeSourceHandle long getDataSizeInBytes(); long getRetainedSizeInBytes(); - - @Override - boolean equals(Object obj); - - @Override - int hashCode(); - - @Override - String toString(); } diff --git a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSourceHandle.java b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSourceHandle.java index 74a1b1b8c500..79d4391092b9 100644 --- a/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSourceHandle.java +++ b/plugin/trino-exchange-filesystem/src/main/java/io/trino/plugin/exchange/filesystem/FileSystemExchangeSourceHandle.java @@ -20,9 +20,7 @@ import io.trino.spi.exchange.ExchangeSourceHandle; import org.openjdk.jol.info.ClassLayout; -import java.util.Arrays; import java.util.List; -import java.util.Objects; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -85,28 +83,6 @@ public Optional getSecretKey() return secretKey; } - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - FileSystemExchangeSourceHandle that = (FileSystemExchangeSourceHandle) o; - if (secretKey.isPresent() && that.secretKey.isPresent()) { - return partitionId == that.getPartitionId() && Arrays.equals(secretKey.get(), that.secretKey.get()); - } - return partitionId == that.getPartitionId() && secretKey.isEmpty() && that.secretKey.isEmpty(); - } - - @Override - public int hashCode() - { - return Objects.hash(partitionId, files, secretKey); - } - @Override public String toString() {