diff --git a/presto-main/src/main/java/com/facebook/presto/spiller/LocalTempStorage.java b/presto-main/src/main/java/com/facebook/presto/spiller/LocalTempStorage.java index be85a98a2588f..78db62b3c8f88 100644 --- a/presto-main/src/main/java/com/facebook/presto/spiller/LocalTempStorage.java +++ b/presto-main/src/main/java/com/facebook/presto/spiller/LocalTempStorage.java @@ -18,6 +18,7 @@ import com.facebook.presto.common.io.DataSink; import com.facebook.presto.common.io.OutputStreamDataSink; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.storage.StorageCapabilities; import com.facebook.presto.spi.storage.TempDataOperationContext; import com.facebook.presto.spi.storage.TempDataSink; import com.facebook.presto.spi.storage.TempStorage; @@ -31,6 +32,8 @@ import java.io.IOException; import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; import java.nio.file.DirectoryStream; import java.nio.file.FileStore; import java.nio.file.Files; @@ -43,6 +46,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.file.Files.createDirectories; import static java.nio.file.Files.delete; import static java.nio.file.Files.getFileStore; @@ -118,6 +122,31 @@ public void remove(TempDataOperationContext context, TempStorageHandle handle) Files.delete(((LocalTempStorageHandle) handle).getFilePath()); } + @Override + public byte[] serializeHandle(TempStorageHandle storageHandle) + { + URI uri = ((LocalTempStorageHandle) storageHandle).getFilePath().toUri(); + return uri.toString().getBytes(UTF_8); + } + + @Override + public TempStorageHandle deserialize(byte[] serializedStorageHandle) + { + String uriString = new String(serializedStorageHandle, UTF_8); + try { + return new LocalTempStorageHandle(Paths.get(new URI(uriString))); + } + 
catch (URISyntaxException e) { + throw new IllegalArgumentException("Invalid URI: " + uriString, e); + } + } + + @Override + public List<StorageCapabilities> getStorageCapabilities() + { + return ImmutableList.of(); + } + private static void cleanupOldSpillFiles(Path path) { try (DirectoryStream<Path> stream = newDirectoryStream(path, SPILL_FILE_GLOB)) { @@ -178,6 +207,12 @@ public Path getFilePath() { return filePath; } + + @Override + public String toString() + { + return filePath.toString(); + } } private static class LocalTempDataSink diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkBroadcastDependency.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkBroadcastDependency.java new file mode 100644 index 0000000000000..23fb4596b6465 --- /dev/null +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkBroadcastDependency.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spark; + +import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput; +import org.apache.spark.SparkException; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; + +import java.util.List; +import java.util.concurrent.TimeoutException; + +public interface PrestoSparkBroadcastDependency<T extends PrestoSparkTaskOutput> +{ + Broadcast<List<T>> executeBroadcast(JavaSparkContext sparkContext) + throws SparkException, TimeoutException; + + void destroy(); +} diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkConfig.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkConfig.java index 6689f2b4b2e72..7c9d0a28712e0 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkConfig.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkConfig.java @@ -21,6 +21,7 @@ import static io.airlift.units.DataSize.Unit.GIGABYTE; import static io.airlift.units.DataSize.Unit.KILOBYTE; +import static io.airlift.units.DataSize.Unit.MEGABYTE; public class PrestoSparkConfig { @@ -30,6 +31,9 @@ public class PrestoSparkConfig private int initialSparkPartitionCount = 16; private DataSize maxSplitsDataSizePerSparkPartition = new DataSize(2, GIGABYTE); private DataSize shuffleOutputTargetAverageRowSize = new DataSize(1, KILOBYTE); + private boolean storageBasedBroadcastJoinEnabled; + private DataSize storageBasedBroadcastJoinWriteBufferSize = new DataSize(24, MEGABYTE); + private String storageBasedBroadcastJoinStorage = "local"; public boolean isSparkPartitionCountAutoTuneEnabled() { @@ -109,4 +113,43 @@ public PrestoSparkConfig setShuffleOutputTargetAverageRowSize(DataSize shuffleOu this.shuffleOutputTargetAverageRowSize = shuffleOutputTargetAverageRowSize; return this; } + + public boolean isStorageBasedBroadcastJoinEnabled() + { + return storageBasedBroadcastJoinEnabled; + } + + @Config("spark.storage-based-broadcast-join-enabled") + 
@ConfigDescription("Distribute broadcast hashtable to workers using storage") + public PrestoSparkConfig setStorageBasedBroadcastJoinEnabled(boolean storageBasedBroadcastJoinEnabled) + { + this.storageBasedBroadcastJoinEnabled = storageBasedBroadcastJoinEnabled; + return this; + } + + public DataSize getStorageBasedBroadcastJoinWriteBufferSize() + { + return storageBasedBroadcastJoinWriteBufferSize; + } + + @Config("spark.storage-based-broadcast-join-write-buffer-size") + @ConfigDescription("Maximum size in bytes to buffer before flushing pages to disk") + public PrestoSparkConfig setStorageBasedBroadcastJoinWriteBufferSize(DataSize storageBasedBroadcastJoinWriteBufferSize) + { + this.storageBasedBroadcastJoinWriteBufferSize = storageBasedBroadcastJoinWriteBufferSize; + return this; + } + + public String getStorageBasedBroadcastJoinStorage() + { + return storageBasedBroadcastJoinStorage; + } + + @Config("spark.storage-based-broadcast-join-storage") + @ConfigDescription("TempStorage to use for dumping broadcast table") + public PrestoSparkConfig setStorageBasedBroadcastJoinStorage(String storageBasedBroadcastJoinStorage) + { + this.storageBasedBroadcastJoinStorage = storageBasedBroadcastJoinStorage; + return this; + } } diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkMemoryBasedBroadcastDependency.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkMemoryBasedBroadcastDependency.java new file mode 100644 index 0000000000000..64af8c6f11b4a --- /dev/null +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkMemoryBasedBroadcastDependency.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spark; + +import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage; +import io.airlift.units.DataSize; +import org.apache.spark.SparkException; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import scala.Tuple2; + +import java.util.List; +import java.util.concurrent.TimeoutException; + +import static com.facebook.presto.ExceededMemoryLimitException.exceededLocalBroadcastMemoryLimit; +import static com.facebook.presto.spark.util.PrestoSparkUtils.computeNextTimeout; +import static io.airlift.units.DataSize.succinctBytes; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.stream.Collectors.toList; + +public class PrestoSparkMemoryBasedBroadcastDependency + implements PrestoSparkBroadcastDependency +{ + private final RddAndMore broadcastDependency; + private final DataSize maxBroadcastSize; + private final long queryCompletionDeadline; + private Broadcast> broadcastVariable; + + public PrestoSparkMemoryBasedBroadcastDependency(RddAndMore broadcastDependency, DataSize maxBroadcastSize, long queryCompletionDeadline) + { + this.broadcastDependency = requireNonNull(broadcastDependency, "broadcastDependency cannot be null"); + this.maxBroadcastSize = requireNonNull(maxBroadcastSize, "maxBroadcastSize cannot be null"); + this.queryCompletionDeadline = queryCompletionDeadline; + } + + @Override + public Broadcast> 
executeBroadcast(JavaSparkContext sparkContext) + throws SparkException, TimeoutException + { + List broadcastValue = broadcastDependency.collectAndDestroyDependenciesWithTimeout(computeNextTimeout(queryCompletionDeadline), MILLISECONDS).stream() + .map(Tuple2::_2) + .collect(toList()); + + long compressedBroadcastSizeInBytes = broadcastValue.stream() + .mapToInt(page -> page.getBytes().length) + .sum(); + long uncompressedBroadcastSizeInBytes = broadcastValue.stream() + .mapToInt(page -> page.getUncompressedSizeInBytes()) + .sum(); + + long maxBroadcastSizeInBytes = maxBroadcastSize.toBytes(); + + if (compressedBroadcastSizeInBytes > maxBroadcastSizeInBytes) { + throw exceededLocalBroadcastMemoryLimit(maxBroadcastSize, format("Compressed broadcast size: %s", succinctBytes(compressedBroadcastSizeInBytes))); + } + + if (uncompressedBroadcastSizeInBytes > maxBroadcastSizeInBytes) { + throw exceededLocalBroadcastMemoryLimit(maxBroadcastSize, format("Uncompressed broadcast size: %s", succinctBytes(uncompressedBroadcastSizeInBytes))); + } + + broadcastVariable = sparkContext.broadcast(broadcastValue); + return broadcastVariable; + } + + @Override + public void destroy() + { + if (broadcastVariable != null) { + broadcastVariable.destroy(); + } + } +} diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java index 2cc9b2b894cdd..81ef89041042a 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java @@ -96,6 +96,7 @@ import com.facebook.presto.server.SessionPropertyDefaults; import com.facebook.presto.server.security.ServerSecurityModule; import com.facebook.presto.spark.classloader_interface.SparkProcessType; +import com.facebook.presto.spark.execution.PrestoSparkBroadcastTableCacheManager; import 
com.facebook.presto.spark.execution.PrestoSparkExecutionExceptionFactory; import com.facebook.presto.spark.execution.PrestoSparkTaskExecutorFactory; import com.facebook.presto.spark.node.PrestoSparkInternalNodeManager; @@ -423,6 +424,7 @@ protected void setup(Binder binder) binder.bind(PrestoSparkTaskExecutorFactory.class).in(Scopes.SINGLETON); binder.bind(PrestoSparkQueryExecutionFactory.class).in(Scopes.SINGLETON); binder.bind(PrestoSparkService.class).in(Scopes.SINGLETON); + binder.bind(PrestoSparkBroadcastTableCacheManager.class).in(Scopes.SINGLETON); // extra credentials and authenticator for Presto-on-Spark newSetBinder(binder, PrestoSparkCredentialsProvider.class); diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkQueryExecutionFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkQueryExecutionFactory.java index 2a0a6b2a7421d..2387fae7bad20 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkQueryExecutionFactory.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkQueryExecutionFactory.java @@ -60,6 +60,7 @@ import com.facebook.presto.spark.classloader_interface.PrestoSparkSession; import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats; import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats.Operation; +import com.facebook.presto.spark.classloader_interface.PrestoSparkStorageHandle; import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider; import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskInputs; import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput; @@ -83,9 +84,13 @@ import com.facebook.presto.spi.page.PagesSerde; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; +import com.facebook.presto.spi.storage.StorageCapabilities; +import 
com.facebook.presto.spi.storage.TempDataOperationContext; +import com.facebook.presto.spi.storage.TempStorage; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.SubPlan; import com.facebook.presto.sql.planner.plan.PlanFragmentId; +import com.facebook.presto.storage.TempStorageManager; import com.facebook.presto.transaction.TransactionId; import com.facebook.presto.transaction.TransactionInfo; import com.facebook.presto.transaction.TransactionManager; @@ -95,7 +100,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; -import com.google.common.util.concurrent.UncheckedExecutionException; import io.airlift.units.DataSize; import io.airlift.units.Duration; import org.apache.spark.SparkContext; @@ -126,12 +130,10 @@ import java.util.OptionalLong; import java.util.Set; import java.util.TreeMap; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue; -import static com.facebook.presto.ExceededMemoryLimitException.exceededLocalBroadcastMemoryLimit; import static com.facebook.presto.SystemSessionProperties.getQueryMaxBroadcastMemory; import static com.facebook.presto.SystemSessionProperties.getQueryMaxExecutionTime; import static com.facebook.presto.SystemSessionProperties.getQueryMaxRunTime; @@ -144,18 +146,23 @@ import static com.facebook.presto.execution.scheduler.StreamingPlanSection.extractStreamingSections; import static com.facebook.presto.execution.scheduler.TableWriteInfo.createTableWriteInfo; import static com.facebook.presto.server.protocol.QueryResourceUtil.toStatementStats; +import static com.facebook.presto.spark.PrestoSparkSessionProperties.isStorageBasedBroadcastJoinEnabled; import static com.facebook.presto.spark.SparkErrorCode.EXCEEDED_SPARK_DRIVER_MAX_RESULT_SIZE; import 
static com.facebook.presto.spark.SparkErrorCode.GENERIC_SPARK_ERROR; import static com.facebook.presto.spark.SparkErrorCode.SPARK_EXECUTOR_LOST; import static com.facebook.presto.spark.SparkErrorCode.SPARK_EXECUTOR_OOM; +import static com.facebook.presto.spark.SparkErrorCode.UNSUPPORTED_STORAGE_TYPE; import static com.facebook.presto.spark.classloader_interface.ScalaUtils.collectScalaIterator; import static com.facebook.presto.spark.classloader_interface.ScalaUtils.emptyScalaIterator; +import static com.facebook.presto.spark.util.PrestoSparkUtils.computeNextTimeout; import static com.facebook.presto.spark.util.PrestoSparkUtils.createPagesSerde; +import static com.facebook.presto.spark.util.PrestoSparkUtils.getActionResultWithTimeout; import static com.facebook.presto.spark.util.PrestoSparkUtils.toSerializedPage; import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_TIME_LIMIT; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.connector.ConnectorCapabilities.SUPPORTS_PAGE_SINK_COMMIT; +import static com.facebook.presto.spi.storage.StorageCapabilities.REMOTELY_ACCESSIBLE; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.textDistributedPlan; @@ -163,7 +170,6 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Throwables.propagateIfPossible; import static com.google.common.base.Ticker.systemTicker; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -171,12 
+177,11 @@ import static com.google.common.util.concurrent.Futures.getUnchecked; import static io.airlift.units.DataSize.succinctBytes; import static java.lang.Math.max; -import static java.lang.String.format; import static java.nio.file.Files.notExists; import static java.util.Collections.unmodifiableList; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static java.util.stream.Collectors.toList; +import static org.apache.spark.util.Utils.isLocalMaster; public class PrestoSparkQueryExecutionFactory implements IPrestoSparkQueryExecutionFactory @@ -206,6 +211,8 @@ public class PrestoSparkQueryExecutionFactory private final Set credentialsProviders; private final Set authenticatorProviders; + private final TempStorageManager tempStorageManager; + private final String storageBasedBroadcastJoinStorage; @Inject public PrestoSparkQueryExecutionFactory( @@ -230,7 +237,9 @@ public PrestoSparkQueryExecutionFactory( SessionPropertyDefaults sessionPropertyDefaults, WarningCollectorFactory warningCollectorFactory, Set credentialsProviders, - Set authenticatorProviders) + Set authenticatorProviders, + TempStorageManager tempStorageManager, + PrestoSparkConfig prestoSparkConfig) { this.queryIdGenerator = requireNonNull(queryIdGenerator, "queryIdGenerator is null"); this.sessionSupplier = requireNonNull(sessionSupplier, "sessionSupplier is null"); @@ -254,6 +263,8 @@ public PrestoSparkQueryExecutionFactory( this.warningCollectorFactory = requireNonNull(warningCollectorFactory, "warningCollectorFactory is null"); this.credentialsProviders = ImmutableSet.copyOf(requireNonNull(credentialsProviders, "credentialsProviders is null")); this.authenticatorProviders = ImmutableSet.copyOf(requireNonNull(authenticatorProviders, "authenticatorProviders is null")); + this.tempStorageManager = requireNonNull(tempStorageManager, "tempStorageManager is null"); + this.storageBasedBroadcastJoinStorage = 
requireNonNull(prestoSparkConfig, "prestoSparkConfig is null").getStorageBasedBroadcastJoinStorage(); } @Override @@ -329,7 +340,7 @@ public IPrestoSparkQueryExecution create( taskInfoCollector.register(sparkContext, new Some<>("taskInfoCollector"), false); CollectionAccumulator shuffleStatsCollector = new CollectionAccumulator<>(); shuffleStatsCollector.register(sparkContext, new Some<>("shuffleStatsCollector"), false); - + TempStorage tempStorage = tempStorageManager.getTempStorage(storageBasedBroadcastJoinStorage); queryStateTimer.endAnalysis(); return new PrestoSparkQueryExecution( @@ -358,7 +369,8 @@ public IPrestoSparkQueryExecution create( queryTimeout, queryCompletionDeadline, queryStatusInfoOutputPath, - queryDataOutputPath); + queryDataOutputPath, + tempStorage); } catch (Throwable executionFailure) { queryStateTimer.beginFinishing(); @@ -711,6 +723,7 @@ public static class PrestoSparkQueryExecution private final Optional queryDataOutputPath; private final long queryCompletionDeadline; + private final TempStorage tempStorage; private PrestoSparkQueryExecution( JavaSparkContext sparkContext, @@ -738,7 +751,8 @@ private PrestoSparkQueryExecution( Duration queryTimeout, long queryCompletionDeadline, Optional queryStatusInfoOutputPath, - Optional queryDataOutputPath) + Optional queryDataOutputPath, + TempStorage tempStorage) { this.sparkContext = requireNonNull(sparkContext, "sparkContext is null"); this.session = requireNonNull(session, "session is null"); @@ -767,6 +781,7 @@ private PrestoSparkQueryExecution( this.queryCompletionDeadline = queryCompletionDeadline; this.queryStatusInfoOutputPath = requireNonNull(queryStatusInfoOutputPath, "queryStatusInfoOutputPath is null"); this.queryDataOutputPath = requireNonNull(queryDataOutputPath, "queryDataOutputPath is null"); + this.tempStorage = requireNonNull(tempStorage, "tempStorage is null"); } @Override @@ -915,7 +930,7 @@ private List> doExecute(Su Map>>> inputFutures = inputRdds.entrySet().stream() 
.collect(toImmutableMap(entry -> entry.getKey().toString(), entry -> entry.getValue().getRdd().collectAsync())); - waitForActionsCompletionWithTimeout(inputFutures.values(), computeNextTimeout(), MILLISECONDS); + waitForActionsCompletionWithTimeout(inputFutures.values(), computeNextTimeout(queryCompletionDeadline), MILLISECONDS); Map> inputs = inputFutures.entrySet().stream() .collect(toImmutableMap( @@ -935,46 +950,46 @@ private List> doExecute(Su } RddAndMore rootRdd = createRdd(root, PrestoSparkSerializedPage.class); - return rootRdd.collectAndDestroyDependenciesWithTimeout(computeNextTimeout(), MILLISECONDS); + return rootRdd.collectAndDestroyDependenciesWithTimeout(computeNextTimeout(queryCompletionDeadline), MILLISECONDS); } private RddAndMore createRdd(SubPlan subPlan, Class outputType) throws SparkException, TimeoutException { ImmutableMap.Builder> rddInputs = ImmutableMap.builder(); - ImmutableMap.Builder>> broadcastInputs = ImmutableMap.builder(); - ImmutableList.Builder> broadcastDependencies = ImmutableList.builder(); + ImmutableMap.Builder>> broadcastInputs = ImmutableMap.builder(); + ImmutableList.Builder broadcastDependencies = ImmutableList.builder(); for (SubPlan child : subPlan.getChildren()) { PlanFragment childFragment = child.getFragment(); if (childFragment.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION)) { - RddAndMore childRdd = createRdd(child, PrestoSparkSerializedPage.class); - - // TODO: The driver might still OOM on a very large broadcast, think of how to prevent that from happening - List broadcastPages = childRdd.collectAndDestroyDependenciesWithTimeout(computeNextTimeout(), MILLISECONDS).stream() - .map(Tuple2::_2) - .collect(toList()); - - int compressedBroadcastSizeInBytes = broadcastPages.stream() - .mapToInt(page -> page.getBytes().length) - .sum(); - int uncompressedBroadcastSizeInBytes = broadcastPages.stream() - .mapToInt(PrestoSparkSerializedPage::getUncompressedSizeInBytes) - 
.sum(); - DataSize maxBroadcastSize = getQueryMaxBroadcastMemory(session); - long maxBroadcastSizeInBytes = maxBroadcastSize.toBytes(); - - if (compressedBroadcastSizeInBytes > maxBroadcastSizeInBytes) { - throw exceededLocalBroadcastMemoryLimit(maxBroadcastSize, format("Compressed broadcast size: %s", succinctBytes(compressedBroadcastSizeInBytes))); + PrestoSparkBroadcastDependency broadcastDependency; + if (isStorageBasedBroadcastJoinEnabled(session)) { + validateStorageCapabilities(tempStorage); + RddAndMore childRdd = createRdd(child, PrestoSparkStorageHandle.class); + TempDataOperationContext tempDataOperationContext = new TempDataOperationContext( + session.getSource(), + session.getQueryId().getId(), + session.getClientInfo(), + session.getIdentity()); + + broadcastDependency = new PrestoSparkStorageBasedBroadcastDependency( + childRdd, + getQueryMaxBroadcastMemory(session), + queryCompletionDeadline, + tempStorage, + tempDataOperationContext); } - - if (uncompressedBroadcastSizeInBytes > maxBroadcastSizeInBytes) { - throw exceededLocalBroadcastMemoryLimit(maxBroadcastSize, format("Uncompressed broadcast size: %s", succinctBytes(compressedBroadcastSizeInBytes))); + else { + RddAndMore childRdd = createRdd(child, PrestoSparkSerializedPage.class); + broadcastDependency = new PrestoSparkMemoryBasedBroadcastDependency( + childRdd, + getQueryMaxBroadcastMemory(session), + queryCompletionDeadline); } - Broadcast> broadcast = sparkContext.broadcast(broadcastPages); - broadcastInputs.put(childFragment.getId(), broadcast); - broadcastDependencies.add(broadcast); + broadcastInputs.put(childFragment.getId(), broadcastDependency.executeBroadcast(sparkContext)); + broadcastDependencies.add(broadcastDependency); } else { RddAndMore childRdd = createRdd(child, PrestoSparkMutableRow.class); @@ -996,6 +1011,15 @@ private RddAndMore createRdd(SubPlan subPla return new RddAndMore<>(rdd, broadcastDependencies.build()); } + private void validateStorageCapabilities(TempStorage 
tempStorage) + { + boolean isLocalMode = isLocalMaster(sparkContext.getConf()); + List storageCapabilities = tempStorage.getStorageCapabilities(); + if (!isLocalMode && !storageCapabilities.contains(REMOTELY_ACCESSIBLE)) { + throw new PrestoException(UNSUPPORTED_STORAGE_TYPE, "Configured TempStorage does not support remote access required for distributing broadcast tables."); + } + } + private void queryCompletedEvent(Optional failureInfo, OptionalLong updateCount) { List serializedTaskInfos = taskInfoCollector.value(); @@ -1089,16 +1113,6 @@ private void logShuffleStatsSummary(ShuffleStatsKey key, List void waitForActionsCompletionWithTimeout(Collection> actions, long timeout, TimeUnit timeUnit) @@ -1124,79 +1138,6 @@ private static void waitForActionsCompletionWithTimeout(Collection T getActionResultWithTimeout(JavaFutureAction action, long timeout, TimeUnit timeUnit) - throws SparkException, TimeoutException - { - long deadline = System.currentTimeMillis() + timeUnit.toMillis(timeout); - try { - while (true) { - long nextTimeoutInMillis = deadline - System.currentTimeMillis(); - if (nextTimeoutInMillis <= 0) { - throw new TimeoutException(); - } - try { - return action.get(nextTimeoutInMillis, MILLISECONDS); - } - catch (TimeoutException e) { - // guard against spurious wakeup - if (deadline - System.currentTimeMillis() <= 0) { - throw e; - } - } - } - } - catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } - catch (ExecutionException e) { - propagateIfPossible(e.getCause(), SparkException.class); - propagateIfPossible(e.getCause(), RuntimeException.class); - - // this should never happen - throw new UncheckedExecutionException(e.getCause()); - } - finally { - if (!action.isDone()) { - action.cancel(true); - } - } - } - - private static class RddAndMore - { - private final JavaPairRDD rdd; - private final List> broadcastDependencies; - - private boolean collected; - - private RddAndMore(JavaPairRDD rdd, 
List> broadcastDependencies) - { - this.rdd = requireNonNull(rdd, "rdd is null"); - this.broadcastDependencies = ImmutableList.copyOf(requireNonNull(broadcastDependencies, "broadcastDependencies is null")); - } - - public List> collectAndDestroyDependenciesWithTimeout(long timeout, TimeUnit timeUnit) - throws SparkException, TimeoutException - { - checkState(!collected, "already collected"); - collected = true; - List> result = getActionResultWithTimeout(rdd.collectAsync(), timeout, timeUnit); - broadcastDependencies.forEach(Broadcast::destroy); - return result; - } - - public JavaPairRDD getRdd() - { - return rdd; - } - - public List> getBroadcastDependencies() - { - return broadcastDependencies; - } - } - private static class ShuffleStatsKey implements Comparable { diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkSessionProperties.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkSessionProperties.java index 914aad8f2c62b..aae16d30fac95 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkSessionProperties.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkSessionProperties.java @@ -34,6 +34,8 @@ public class PrestoSparkSessionProperties public static final String SPARK_INITIAL_PARTITION_COUNT = "spark_initial_partition_count"; public static final String MAX_SPLITS_DATA_SIZE_PER_SPARK_PARTITION = "max_splits_data_size_per_spark_partition"; public static final String SHUFFLE_OUTPUT_TARGET_AVERAGE_ROW_SIZE = "shuffle_output_target_average_row_size"; + public static final String STORAGE_BASED_BROADCAST_JOIN_ENABLED = "storage_based_broadcast_join_enabled"; + public static final String STORAGE_BASED_BROADCAST_JOIN_WRITE_BUFFER_SIZE = "storage_based_broadcast_join_write_buffer_size"; private final List> sessionProperties; @@ -70,6 +72,16 @@ public PrestoSparkSessionProperties(PrestoSparkConfig prestoSparkConfig) SHUFFLE_OUTPUT_TARGET_AVERAGE_ROW_SIZE, 
"Target average size for row entries produced by Presto on Spark for shuffle", prestoSparkConfig.getShuffleOutputTargetAverageRowSize(), + false), + booleanProperty( + STORAGE_BASED_BROADCAST_JOIN_ENABLED, + "Use storage for distributing broadcast table", + prestoSparkConfig.isStorageBasedBroadcastJoinEnabled(), + false), + dataSizeProperty( + STORAGE_BASED_BROADCAST_JOIN_WRITE_BUFFER_SIZE, + "Maximum size in bytes to buffer before flushing pages to disk", + prestoSparkConfig.getStorageBasedBroadcastJoinWriteBufferSize(), false)); } @@ -107,4 +119,14 @@ public static DataSize getShuffleOutputTargetAverageRowSize(Session session) { return session.getSystemProperty(SHUFFLE_OUTPUT_TARGET_AVERAGE_ROW_SIZE, DataSize.class); } + + public static boolean isStorageBasedBroadcastJoinEnabled(Session session) + { + return session.getSystemProperty(STORAGE_BASED_BROADCAST_JOIN_ENABLED, Boolean.class); + } + + public static DataSize getStorageBasedBroadcastJoinWriteBufferSize(Session session) + { + return session.getSystemProperty(STORAGE_BASED_BROADCAST_JOIN_WRITE_BUFFER_SIZE, DataSize.class); + } } diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkStorageBasedBroadcastDependency.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkStorageBasedBroadcastDependency.java new file mode 100644 index 0000000000000..6047bdf62fc30 --- /dev/null +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkStorageBasedBroadcastDependency.java @@ -0,0 +1,119 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
package com.facebook.presto.spark;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.spark.classloader_interface.PrestoSparkStorageHandle;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.storage.TempDataOperationContext;
import com.facebook.presto.spi.storage.TempStorage;
import com.facebook.presto.spi.storage.TempStorageHandle;
import io.airlift.units.DataSize;
import org.apache.spark.SparkException;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeoutException;

import static com.facebook.presto.ExceededMemoryLimitException.exceededLocalBroadcastMemoryLimit;
import static com.facebook.presto.spark.SparkErrorCode.STORAGE_ERROR;
import static com.facebook.presto.spark.util.PrestoSparkUtils.computeNextTimeout;
import static io.airlift.units.DataSize.succinctBytes;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.stream.Collectors.toList;

/**
 * Broadcast dependency that distributes a broadcast table via temp storage:
 * instead of shipping serialized pages through the Spark broadcast variable,
 * only lightweight {@link PrestoSparkStorageHandle}s (pointers to spill files
 * plus size/checksum metadata) are broadcast. Readers load the pages from
 * {@link TempStorage} on the executor side.
 *
 * <p>NOTE(review): generic parameters were reconstructed from usage
 * ({@code Tuple2::_2} yields {@code PrestoSparkStorageHandle}) — confirm
 * against {@code PrestoSparkBroadcastDependency}'s declaration.
 */
public class PrestoSparkStorageBasedBroadcastDependency
        implements PrestoSparkBroadcastDependency<PrestoSparkStorageHandle>
{
    private static final Logger log = Logger.get(PrestoSparkStorageBasedBroadcastDependency.class);

    private final RddAndMore<PrestoSparkStorageHandle> broadcastDependency;
    private final DataSize maxBroadcastSize;
    private final long queryCompletionDeadline;
    private final TempStorage tempStorage;
    private final TempDataOperationContext tempDataOperationContext;

    // Set once executeBroadcast succeeds; null until then (destroy() is a no-op before that)
    private Broadcast<List<PrestoSparkStorageHandle>> broadcastVariable;

    public PrestoSparkStorageBasedBroadcastDependency(
            RddAndMore<PrestoSparkStorageHandle> broadcastDependency,
            DataSize maxBroadcastSize,
            long queryCompletionDeadline,
            TempStorage tempStorage,
            TempDataOperationContext tempDataOperationContext)
    {
        this.broadcastDependency = requireNonNull(broadcastDependency, "broadcastDependency cannot be null");
        this.maxBroadcastSize = requireNonNull(maxBroadcastSize, "maxBroadcastSize cannot be null");
        this.queryCompletionDeadline = queryCompletionDeadline;
        this.tempStorage = requireNonNull(tempStorage, "tempStorage cannot be null");
        this.tempDataOperationContext = requireNonNull(tempDataOperationContext, "tempDataOperationContext cannot be null");
    }

    /**
     * Collects the storage handles produced by the build side, enforces the
     * broadcast memory limit against both compressed and uncompressed sizes,
     * and broadcasts the (small) handle list to all executors.
     *
     * @throws TimeoutException if collection does not finish before the query completion deadline
     */
    @Override
    public Broadcast<List<PrestoSparkStorageHandle>> executeBroadcast(JavaSparkContext sparkContext)
            throws SparkException, TimeoutException
    {
        List<PrestoSparkStorageHandle> broadcastValue = broadcastDependency
                .collectAndDestroyDependenciesWithTimeout(computeNextTimeout(queryCompletionDeadline), MILLISECONDS).stream()
                .map(Tuple2::_2)
                .collect(toList());

        long compressedBroadcastSizeInBytes = broadcastValue.stream()
                .mapToLong(PrestoSparkStorageHandle::getCompressedSizeInBytes)
                .sum();
        long uncompressedBroadcastSizeInBytes = broadcastValue.stream()
                .mapToLong(PrestoSparkStorageHandle::getUncompressedSizeInBytes)
                .sum();

        log.info("Got back %d pages. compressedBroadcastSizeInBytes: %d; uncompressedBroadcastSizeInBytes: %d",
                broadcastValue.size(),
                compressedBroadcastSizeInBytes,
                uncompressedBroadcastSizeInBytes);

        long maxBroadcastSizeInBytes = maxBroadcastSize.toBytes();

        if (compressedBroadcastSizeInBytes > maxBroadcastSizeInBytes) {
            throw exceededLocalBroadcastMemoryLimit(maxBroadcastSize, format("Compressed broadcast size: %s", succinctBytes(compressedBroadcastSizeInBytes)));
        }

        if (uncompressedBroadcastSizeInBytes > maxBroadcastSizeInBytes) {
            throw exceededLocalBroadcastMemoryLimit(maxBroadcastSize, format("Uncompressed broadcast size: %s", succinctBytes(uncompressedBroadcastSizeInBytes)));
        }

        broadcastVariable = sparkContext.broadcast(broadcastValue);
        return broadcastVariable;
    }

    /**
     * Deletes all spill files referenced by the broadcast handles and destroys
     * the broadcast variable.
     *
     * <p>Fix: the original aborted the deletion loop on the first
     * {@link IOException}, leaking every remaining spill file, and skipped
     * {@code broadcastVariable.destroy()} entirely. Deletion is now
     * best-effort (failures are collected as suppressed exceptions) and the
     * broadcast variable is always destroyed in a {@code finally} block.
     */
    @Override
    public void destroy()
    {
        if (broadcastVariable == null) {
            return;
        }

        try {
            IOException deleteFailure = null;
            for (PrestoSparkStorageHandle diskPage : broadcastVariable.getValue()) {
                try {
                    TempStorageHandle storageHandle = tempStorage.deserialize(diskPage.getSerializedStorageHandle());
                    tempStorage.remove(tempDataOperationContext, storageHandle);
                    log.info("Deleted broadcast spill file: " + storageHandle.toString());
                }
                catch (IOException e) {
                    // Keep deleting the remaining files; remember every failure
                    if (deleteFailure == null) {
                        deleteFailure = e;
                    }
                    else {
                        deleteFailure.addSuppressed(e);
                    }
                }
            }
            if (deleteFailure != null) {
                throw new PrestoException(STORAGE_ERROR, "Unable to delete broadcast spill file", deleteFailure);
            }
        }
        finally {
            // Always release executor-side broadcast memory, even on deletion failure
            broadcastVariable.destroy();
        }
    }
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spark; + +import com.facebook.presto.spark.classloader_interface.MutablePartitionId; +import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput; +import com.google.common.collect.ImmutableList; +import org.apache.spark.SparkException; +import org.apache.spark.api.java.JavaPairRDD; +import scala.Tuple2; + +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static com.facebook.presto.spark.util.PrestoSparkUtils.getActionResultWithTimeout; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class RddAndMore +{ + private final JavaPairRDD rdd; + private final List broadcastDependencies; + + private boolean collected; + + public RddAndMore( + JavaPairRDD rdd, + List broadcastDependencies) + { + this.rdd = requireNonNull(rdd, "rdd is null"); + this.broadcastDependencies = ImmutableList.copyOf(requireNonNull(broadcastDependencies, "broadcastDependencies is null")); + } + + public List> collectAndDestroyDependenciesWithTimeout(long timeout, TimeUnit timeUnit) + throws SparkException, TimeoutException + { + checkState(!collected, "already collected"); + collected = true; + List> result = getActionResultWithTimeout(rdd.collectAsync(), timeout, timeUnit); + broadcastDependencies.forEach(PrestoSparkBroadcastDependency::destroy); + return result; + } + + public JavaPairRDD getRdd() + { + return rdd; + } + + public List getBroadcastDependencies() + { + return 
package com.facebook.presto.spark.execution;

import com.facebook.presto.common.Page;
import com.facebook.presto.execution.StageId;
import com.facebook.presto.spi.plan.PlanNodeId;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static java.util.Objects.requireNonNull;

/**
 * Caches deserialized broadcast tables (hash-table build sides) per
 * (stage, plan node), together with a running total of their retained size.
 *
 * <p>Currently we cache tables from a single stage. When a task from another
 * stage is scheduled, the cache for all other stages is cleared.
 *
 * <p>Thread-safe: all public methods synchronize on {@code this}.
 */
public class PrestoSparkBroadcastTableCacheManager
{
    private final Map<BroadcastTableCacheKey, List<List<Page>>> cache = new HashMap<>();
    // Running total of Page.getRetainedSizeInBytes() over all cached tables;
    // kept in sync with every cache mutation
    private long cacheSizeInBytes;

    /**
     * Evicts every cached table that does not belong to {@code stageId},
     * adjusting the size accounting for each removed entry.
     */
    public synchronized void removeCachedTablesForStagesOtherThan(StageId stageId)
    {
        Iterator<Map.Entry<BroadcastTableCacheKey, List<List<Page>>>> iterator = cache.entrySet().iterator();
        while (iterator.hasNext()) {
            Map.Entry<BroadcastTableCacheKey, List<List<Page>>> entry = iterator.next();
            if (!entry.getKey().getStageId().equals(stageId)) {
                cacheSizeInBytes -= getRetainedSizeInBytes(entry.getValue());
                iterator.remove();
            }
        }
    }

    /**
     * @return the cached table for the given stage and plan node, or null if absent
     */
    public synchronized List<List<Page>> getCachedBroadcastTable(StageId stageId, PlanNodeId planNodeId)
    {
        return cache.get(new BroadcastTableCacheKey(stageId, planNodeId));
    }

    /**
     * Caches a broadcast table under (stageId, planNodeId).
     *
     * <p>Fix: the original unconditionally added the new table's size to
     * {@code cacheSizeInBytes} without subtracting the size of an entry it
     * replaced, so re-caching the same key double-counted memory forever.
     * The displaced entry's size is now subtracted first.
     */
    public synchronized void cache(StageId stageId, PlanNodeId planNodeId, List<List<Page>> broadcastTable)
    {
        List<List<Page>> previous = cache.put(new BroadcastTableCacheKey(stageId, planNodeId), broadcastTable);
        if (previous != null) {
            cacheSizeInBytes -= getRetainedSizeInBytes(previous);
        }
        cacheSizeInBytes += getRetainedSizeInBytes(broadcastTable);
    }

    public synchronized long getCacheSizeInBytes()
    {
        return cacheSizeInBytes;
    }

    // Sum of retained sizes over a table's pages (one inner list per producer task)
    private static long getRetainedSizeInBytes(List<List<Page>> broadcastTable)
    {
        return broadcastTable.stream()
                .mapToLong(pageList -> pageList.stream().mapToLong(Page::getRetainedSizeInBytes).sum())
                .sum();
    }

    /**
     * Immutable cache key: (stageId, planNodeId).
     */
    private static class BroadcastTableCacheKey
    {
        private final StageId stageId;
        private final PlanNodeId planNodeId;

        public BroadcastTableCacheKey(StageId stageId, PlanNodeId planNodeId)
        {
            this.stageId = requireNonNull(stageId, "stageId is null");
            this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
        }

        @Override
        public boolean equals(Object o)
        {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            BroadcastTableCacheKey that = (BroadcastTableCacheKey) o;
            return Objects.equals(stageId, that.stageId) &&
                    Objects.equals(planNodeId, that.planNodeId);
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(stageId, planNodeId);
        }

        public StageId getStageId()
        {
            return stageId;
        }
    }
}
package com.facebook.presto.spark.execution;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.common.Page;
import com.facebook.presto.execution.StageId;
import com.facebook.presto.spark.classloader_interface.PrestoSparkStorageHandle;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.page.PagesSerde;
import com.facebook.presto.spi.page.SerializedPage;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.storage.TempDataOperationContext;
import com.facebook.presto.spi.storage.TempStorage;
import com.facebook.presto.spi.storage.TempStorageHandle;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.InputStreamSliceInput;

import javax.annotation.concurrent.GuardedBy;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.zip.CRC32;

import static com.facebook.presto.spark.SparkErrorCode.STORAGE_ERROR;
import static com.facebook.presto.spi.page.PagesSerdeUtil.readSerializedPages;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Collections.shuffle;
import static java.util.Objects.requireNonNull;

/**
 * Page input that reads a broadcast table from temp storage. The files are
 * loaded lazily on first {@link #getNextPage()}, verified against a CRC32
 * checksum, deserialized, and cached in the shared
 * {@link PrestoSparkBroadcastTableCacheManager} so other drivers of the same
 * stage/plan node reuse the deserialized pages.
 */
public class PrestoSparkDiskPageInput
        implements PrestoSparkPageInput
{
    private static final Logger log = Logger.get(PrestoSparkDiskPageInput.class);

    private final PagesSerde pagesSerde;
    private final TempStorage tempStorage;
    private final TempDataOperationContext tempDataOperationContext;
    private final PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager;
    private final StageId stageId;
    private final PlanNodeId planNodeId;
    // One inner list per producer task; elements are storage handles
    // (NOTE(review): element type reconstructed from the cast below — confirm)
    private final List<List<PrestoSparkTaskOutput>> broadcastTableFilesInfo;

    // Lazily populated under "this"; one iterator per producer task
    @GuardedBy("this")
    private List<Iterator<Page>> pageIterators;
    @GuardedBy("this")
    private int currentIteratorIndex;

    public PrestoSparkDiskPageInput(
            PagesSerde pagesSerde,
            TempStorage tempStorage,
            TempDataOperationContext tempDataOperationContext,
            PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager,
            StageId stageId,
            PlanNodeId planNodeId,
            List<List<PrestoSparkTaskOutput>> broadcastTableFilesInfo)
    {
        this.pagesSerde = requireNonNull(pagesSerde, "pagesSerde is null");
        this.tempStorage = requireNonNull(tempStorage, "tempStorage is null");
        this.tempDataOperationContext = requireNonNull(tempDataOperationContext, "tempDataOperationContext is null");
        this.prestoSparkBroadcastTableCacheManager = requireNonNull(prestoSparkBroadcastTableCacheManager, "prestoSparkBroadcastTableCacheManager is null");
        this.stageId = requireNonNull(stageId, "stageId is null");
        this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
        this.broadcastTableFilesInfo = requireNonNull(broadcastTableFilesInfo, "broadcastTableFilesInfo is null");
    }

    /**
     * Returns the next page of the broadcast table, or null when exhausted.
     * The first call triggers loading (or cache lookup) of the whole table.
     */
    @Override
    public synchronized Page getNextPage()
    {
        List<Iterator<Page>> iterators = getPageIterators();
        while (currentIteratorIndex < iterators.size()) {
            Iterator<Page> iterator = iterators.get(currentIteratorIndex);
            if (iterator.hasNext()) {
                return iterator.next();
            }
            // Current producer's pages exhausted; advance to the next one
            currentIteratorIndex++;
        }
        return null;
    }

    // Lazy initialization; callers must hold the monitor on "this"
    private List<Iterator<Page>> getPageIterators()
    {
        if (pageIterators == null) {
            pageIterators = getPages(broadcastTableFilesInfo, tempStorage, tempDataOperationContext, prestoSparkBroadcastTableCacheManager, stageId, planNodeId);
        }
        return pageIterators;
    }

    /**
     * Fetches the deserialized table from the shared cache, or loads it from
     * temp storage, deserializes it, and populates the cache.
     */
    private List<Iterator<Page>> getPages(
            List<List<PrestoSparkTaskOutput>> broadcastTableFilesInfo,
            TempStorage tempStorage,
            TempDataOperationContext tempDataOperationContext,
            PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager,
            StageId stageId,
            PlanNodeId planNodeId)
    {
        List<List<Page>> pages = prestoSparkBroadcastTableCacheManager.getCachedBroadcastTable(stageId, planNodeId);
        if (pages == null) {
            ImmutableList.Builder<List<Page>> deserializedTables = ImmutableList.builder();
            for (List<PrestoSparkTaskOutput> tableFiles : broadcastTableFilesInfo) {
                List<SerializedPage> serializedPages = loadBroadcastTable(tableFiles, tempStorage, tempDataOperationContext);
                List<Page> deserialized = serializedPages.stream()
                        .map(pagesSerde::deserialize)
                        .collect(toImmutableList());
                deserializedTables.add(deserialized);
            }
            pages = deserializedTables.build();

            // Share the deserialized pages with other drivers of this stage/node
            prestoSparkBroadcastTableCacheManager.cache(stageId, planNodeId, pages);
        }

        return pages.stream().map(List::iterator).collect(toImmutableList());
    }

    /**
     * Reads every spill file of one producer task, verifying each file's CRC32
     * checksum against the value recorded in its storage handle. Files are
     * read in a randomized order to spread storage load across readers.
     *
     * @throws PrestoException with STORAGE_ERROR on I/O failure or checksum mismatch
     */
    private List<SerializedPage> loadBroadcastTable(
            List<PrestoSparkTaskOutput> broadcastTaskFilesInfo,
            TempStorage tempStorage,
            TempDataOperationContext tempDataOperationContext)
    {
        try {
            CRC32 checksum = new CRC32();
            ImmutableList.Builder<SerializedPage> pages = ImmutableList.builder();
            List<PrestoSparkTaskOutput> broadcastTaskFilesInfoCopy = new ArrayList<>(broadcastTaskFilesInfo);
            shuffle(broadcastTaskFilesInfoCopy);
            for (PrestoSparkTaskOutput taskFileInfo : broadcastTaskFilesInfoCopy) {
                checksum.reset();
                PrestoSparkStorageHandle prestoSparkStorageHandle = (PrestoSparkStorageHandle) taskFileInfo;
                TempStorageHandle tempStorageHandle = tempStorage.deserialize(prestoSparkStorageHandle.getSerializedStorageHandle());
                log.info("Reading path: " + tempStorageHandle.toString());
                try (InputStream inputStream = tempStorage.open(tempDataOperationContext, tempStorageHandle);
                        InputStreamSliceInput inputStreamSliceInput = new InputStreamSliceInput(inputStream)) {
                    Iterator<SerializedPage> pagesIterator = readSerializedPages(inputStreamSliceInput);
                    while (pagesIterator.hasNext()) {
                        SerializedPage page = pagesIterator.next();
                        checksum.update(page.getSlice().byteArray(), page.getSlice().byteArrayOffset(), page.getSlice().length());
                        pages.add(page);
                    }
                }

                if (checksum.getValue() != prestoSparkStorageHandle.getChecksum()) {
                    throw new PrestoException(STORAGE_ERROR, "Disk page checksum does not match. " +
                            "Data seems to be corrupted on disk for file " + tempStorageHandle.toString());
                }
            }
            return pages.build();
        }
        catch (IOException e) {
            throw new PrestoException(STORAGE_ERROR, "Unable to read data from disk: ", e);
        }
    }

    /**
     * Retained size of the shared broadcast-table cache; reported as system
     * memory by the first remote-source operator only (the cache is shared).
     */
    public long getRetainedSizeInBytes()
    {
        return prestoSparkBroadcastTableCacheManager.getCacheSizeInBytes();
    }
}
com.facebook.presto.operator.SourceOperatorFactory; import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage; import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats; +import com.facebook.presto.spark.classloader_interface.PrestoSparkStorageHandle; import com.facebook.presto.spark.execution.PrestoSparkRemoteSourceOperator.SparkRemoteSourceOperatorFactory; import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.storage.TempDataOperationContext; +import com.facebook.presto.spi.storage.TempStorage; import com.facebook.presto.sql.planner.RemoteSourceFactory; import com.google.common.collect.ImmutableMap; import org.apache.spark.util.CollectionAccumulator; @@ -30,8 +34,10 @@ import java.util.List; import java.util.Map; +import static com.facebook.presto.spark.PrestoSparkSessionProperties.isStorageBasedBroadcastJoinEnabled; import static com.facebook.presto.spark.util.PrestoSparkUtils.createPagesSerde; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class PrestoSparkRemoteSourceFactory @@ -40,21 +46,36 @@ public class PrestoSparkRemoteSourceFactory private final BlockEncodingManager blockEncodingManager; private final Map> shuffleInputsMap; private final Map>> pageInputsMap; + private final Map>> broadcastInputsMap; private final int taskId; private final CollectionAccumulator shuffleStatsCollector; + private final TempStorage tempStorage; + private final TempDataOperationContext tempDataOperationContext; + private final PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager; + private final StageId stageId; public PrestoSparkRemoteSourceFactory( BlockEncodingManager blockEncodingManager, Map> shuffleInputsMap, Map>> pageInputsMap, + Map>> broadcastInputsMap, int taskId, - CollectionAccumulator shuffleStatsCollector) + 
CollectionAccumulator shuffleStatsCollector, + TempStorage tempStorage, + TempDataOperationContext tempDataOperationContext, + PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager, + StageId stageId) { this.blockEncodingManager = requireNonNull(blockEncodingManager, "blockEncodingManager is null"); this.shuffleInputsMap = ImmutableMap.copyOf(requireNonNull(shuffleInputsMap, "shuffleInputsMap is null")); this.pageInputsMap = ImmutableMap.copyOf(requireNonNull(pageInputsMap, "pageInputs is null")); + this.broadcastInputsMap = ImmutableMap.copyOf(requireNonNull(broadcastInputsMap, "broadcastInputsMap is null")); this.taskId = taskId; this.shuffleStatsCollector = requireNonNull(shuffleStatsCollector, "shuffleStatsCollector is null"); + this.tempDataOperationContext = requireNonNull(tempDataOperationContext, "tempDataOperationContext is null"); + this.tempStorage = requireNonNull(tempStorage, "tempStorage is null"); + this.prestoSparkBroadcastTableCacheManager = requireNonNull(prestoSparkBroadcastTableCacheManager, "prestoSparkBroadcastTableCacheManager is null"); + this.stageId = requireNonNull(stageId, "stageId is null"); } @Override @@ -62,8 +83,34 @@ public SourceOperatorFactory createRemoteSource(Session session, int operatorId, { List shuffleInputs = shuffleInputsMap.get(planNodeId); List> pageInputs = pageInputsMap.get(planNodeId); - checkArgument(shuffleInputs != null || pageInputs != null, "input not found for plan node with id %s", planNodeId); + List> broadcastInputs = broadcastInputsMap.get(planNodeId); + checkArgument(shuffleInputs != null || pageInputs != null || broadcastInputs != null, "input not found for plan node with id %s", planNodeId); checkArgument(shuffleInputs == null || pageInputs == null, "single remote source cannot accept both, shuffle and page inputs"); + if (broadcastInputs != null) { + if (isStorageBasedBroadcastJoinEnabled(session)) { + List> diskPageInputs = + broadcastInputs.stream().map(input -> ((List) 
input)).collect(toImmutableList()); + return new SparkRemoteSourceOperatorFactory( + operatorId, + planNodeId, + new PrestoSparkDiskPageInput( + createPagesSerde(blockEncodingManager), + tempStorage, + tempDataOperationContext, + prestoSparkBroadcastTableCacheManager, + stageId, + planNodeId, + diskPageInputs)); + } + else { + List> serializedPageInputs = + broadcastInputs.stream().map(input -> ((List) input).iterator()).collect(toImmutableList()); + return new SparkRemoteSourceOperatorFactory( + operatorId, + planNodeId, + new PrestoSparkSerializedPageInput(createPagesSerde(blockEncodingManager), serializedPageInputs)); + } + } if (pageInputs != null) { return new SparkRemoteSourceOperatorFactory( diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkRemoteSourceOperator.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkRemoteSourceOperator.java index f4fce0bc2788f..08b40a3c3a198 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkRemoteSourceOperator.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkRemoteSourceOperator.java @@ -14,6 +14,7 @@ package com.facebook.presto.spark.execution; import com.facebook.presto.common.Page; +import com.facebook.presto.memory.context.LocalMemoryContext; import com.facebook.presto.metadata.Split; import com.facebook.presto.operator.DriverContext; import com.facebook.presto.operator.OperatorContext; @@ -33,15 +34,19 @@ public class PrestoSparkRemoteSourceOperator { private final PlanNodeId sourceId; private final OperatorContext operatorContext; + private final LocalMemoryContext systemMemoryContext; private final PrestoSparkPageInput pageInput; + private final boolean isFirstOperator; private boolean finished; - public PrestoSparkRemoteSourceOperator(PlanNodeId sourceId, OperatorContext operatorContext, PrestoSparkPageInput pageInput) + public 
PrestoSparkRemoteSourceOperator(PlanNodeId sourceId, OperatorContext operatorContext, LocalMemoryContext systemMemoryContext, PrestoSparkPageInput pageInput, boolean isFirstOperator) { this.sourceId = requireNonNull(sourceId, "sourceId is null"); this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); + this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); this.pageInput = requireNonNull(pageInput, "pageInput is null"); + this.isFirstOperator = isFirstOperator; } @Override @@ -70,6 +75,7 @@ public Page getOutput() } Page page = pageInput.getNextPage(); + updateMemoryContext(); if (page == null) { finished = true; return null; @@ -107,12 +113,22 @@ public void noMoreSplits() throw new UnsupportedOperationException(); } + private void updateMemoryContext() + { + // Since the cache is shared, only the first PrestoSparkRemoteSourceOperator should report the cache memory + if (isFirstOperator && pageInput instanceof PrestoSparkDiskPageInput) { + PrestoSparkDiskPageInput diskPageInput = (PrestoSparkDiskPageInput) pageInput; + systemMemoryContext.setBytes(diskPageInput.getRetainedSizeInBytes()); + } + } + public static class SparkRemoteSourceOperatorFactory implements SourceOperatorFactory { private final int operatorId; private final PlanNodeId planNodeId; private final PrestoSparkPageInput pageInput; + private boolean isFirstOperator = true; private boolean closed; @@ -133,10 +149,15 @@ public PlanNodeId getSourceId() public SourceOperator createOperator(DriverContext driverContext) { checkState(!closed, "operator factory is closed"); - return new PrestoSparkRemoteSourceOperator( + OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, PrestoSparkRemoteSourceOperator.class.getSimpleName()); + SourceOperator operator = new PrestoSparkRemoteSourceOperator( planNodeId, - driverContext.addOperatorContext(operatorId, planNodeId, 
PrestoSparkRemoteSourceOperator.class.getSimpleName()), - pageInput); + operatorContext, + operatorContext.newLocalSystemMemoryContext(PrestoSparkRemoteSourceOperator.class.getSimpleName()), + pageInput, + isFirstOperator); + isFirstOperator = false; + return operator; } @Override diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkTaskExecutorFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkTaskExecutorFactory.java index 567c143d8dc9d..e488138d7dbd4 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkTaskExecutorFactory.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkTaskExecutorFactory.java @@ -18,6 +18,7 @@ import com.facebook.airlift.stats.TestingGcMonitor; import com.facebook.presto.Session; import com.facebook.presto.common.block.BlockEncodingManager; +import com.facebook.presto.common.io.DataOutput; import com.facebook.presto.event.SplitMonitor; import com.facebook.presto.execution.ExecutionFailureInfo; import com.facebook.presto.execution.ScheduledSplit; @@ -43,6 +44,7 @@ import com.facebook.presto.operator.TaskContext; import com.facebook.presto.operator.TaskStats; import com.facebook.presto.spark.PrestoSparkAuthenticatorProvider; +import com.facebook.presto.spark.PrestoSparkConfig; import com.facebook.presto.spark.PrestoSparkTaskDescriptor; import com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutor; import com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutorFactory; @@ -50,6 +52,7 @@ import com.facebook.presto.spark.classloader_interface.PrestoSparkMutableRow; import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage; import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats; +import com.facebook.presto.spark.classloader_interface.PrestoSparkStorageHandle; import 
com.facebook.presto.spark.classloader_interface.PrestoSparkTaskInputs; import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput; import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor; @@ -61,8 +64,13 @@ import com.facebook.presto.spark.execution.PrestoSparkRowOutputOperator.PrestoSparkRowOutputFactory; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.memory.MemoryPoolId; +import com.facebook.presto.spi.page.PageDataOutput; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.security.TokenAuthenticator; +import com.facebook.presto.spi.storage.TempDataOperationContext; +import com.facebook.presto.spi.storage.TempDataSink; +import com.facebook.presto.spi.storage.TempStorage; +import com.facebook.presto.spi.storage.TempStorageHandle; import com.facebook.presto.spiller.NodeSpillConfig; import com.facebook.presto.spiller.SpillSpaceTracker; import com.facebook.presto.sql.planner.LocalExecutionPlanner; @@ -72,10 +80,12 @@ import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; +import com.facebook.presto.storage.TempStorageManager; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.airlift.slice.Slice; import io.airlift.units.DataSize; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.util.CollectionAccumulator; @@ -86,6 +96,8 @@ import javax.inject.Inject; +import java.io.IOException; +import java.io.UncheckedIOException; import java.net.URI; import java.util.ArrayList; import java.util.List; @@ -98,6 +110,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; +import 
java.util.zip.CRC32; import static com.facebook.presto.SystemSessionProperties.getHashPartitionCount; import static com.facebook.presto.SystemSessionProperties.getQueryMaxBroadcastMemory; @@ -108,6 +121,7 @@ import static com.facebook.presto.execution.buffer.BufferState.FINISHED; import static com.facebook.presto.metadata.MetadataUpdates.DEFAULT_METADATA_UPDATES; import static com.facebook.presto.spark.PrestoSparkSessionProperties.getShuffleOutputTargetAverageRowSize; +import static com.facebook.presto.spark.PrestoSparkSessionProperties.getStorageBasedBroadcastJoinWriteBufferSize; import static com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats.Operation.WRITE; import static com.facebook.presto.spark.util.PrestoSparkUtils.compress; import static com.facebook.presto.spark.util.PrestoSparkUtils.decompress; @@ -154,6 +168,9 @@ public class PrestoSparkTaskExecutorFactory private final boolean perOperatorAllocationTrackingEnabled; private final boolean allocationTrackingEnabled; + private final TempStorageManager tempStorageManager; + private final PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager; + private final String storageBasedBroadcastJoinStorage; @Inject public PrestoSparkTaskExecutorFactory( @@ -175,7 +192,10 @@ public PrestoSparkTaskExecutorFactory( ObjectMapper objectMapper, TaskManagerConfig taskManagerConfig, NodeMemoryConfig nodeMemoryConfig, - NodeSpillConfig nodeSpillConfig) + NodeSpillConfig nodeSpillConfig, + TempStorageManager tempStorageManager, + PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager, + PrestoSparkConfig prestoSparkConfig) { this( sessionPropertyManager, @@ -198,7 +218,10 @@ public PrestoSparkTaskExecutorFactory( requireNonNull(taskManagerConfig, "taskManagerConfig is null").isPerOperatorCpuTimerEnabled(), requireNonNull(taskManagerConfig, "taskManagerConfig is null").isTaskCpuTimerEnabled(), requireNonNull(taskManagerConfig, "taskManagerConfig is 
null").isPerOperatorAllocationTrackingEnabled(), - requireNonNull(taskManagerConfig, "taskManagerConfig is null").isTaskAllocationTrackingEnabled()); + requireNonNull(taskManagerConfig, "taskManagerConfig is null").isTaskAllocationTrackingEnabled(), + tempStorageManager, + requireNonNull(prestoSparkConfig, "prestoSparkConfig is null").getStorageBasedBroadcastJoinStorage(), + prestoSparkBroadcastTableCacheManager); } public PrestoSparkTaskExecutorFactory( @@ -222,7 +245,10 @@ public PrestoSparkTaskExecutorFactory( boolean perOperatorCpuTimerEnabled, boolean cpuTimerEnabled, boolean perOperatorAllocationTrackingEnabled, - boolean allocationTrackingEnabled) + boolean allocationTrackingEnabled, + TempStorageManager tempStorageManager, + String storageBasedBroadcastJoinStorage, + PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager) { this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null"); this.blockEncodingManager = requireNonNull(blockEncodingManager, "blockEncodingManager is null"); @@ -246,6 +272,9 @@ public PrestoSparkTaskExecutorFactory( this.cpuTimerEnabled = cpuTimerEnabled; this.perOperatorAllocationTrackingEnabled = perOperatorAllocationTrackingEnabled; this.allocationTrackingEnabled = allocationTrackingEnabled; + this.tempStorageManager = requireNonNull(tempStorageManager, "tempStorageManager is null"); + this.storageBasedBroadcastJoinStorage = requireNonNull(storageBasedBroadcastJoinStorage, "storageBasedBroadcastJoinStorage is null"); + this.prestoSparkBroadcastTableCacheManager = requireNonNull(prestoSparkBroadcastTableCacheManager, "prestoSparkBroadcastTableCacheManager is null"); } @Override @@ -295,6 +324,11 @@ public IPrestoSparkTaskExecutor doCreate( extraAuthenticators.build()); PlanFragment fragment = taskDescriptor.getFragment(); StageId stageId = new StageId(session.getQueryId(), fragment.getId().getId()); + + // Clear the cache if the cache does not have broadcast table for 
current stageId. + // We will only cache 1 HT at any time. If the stageId changes, we will drop the old cached HT + prestoSparkBroadcastTableCacheManager.removeCachedTablesForStagesOtherThan(stageId); + // TODO: include attemptId in taskId TaskId taskId = new TaskId(new StageExecutionId(stageId, 0), partitionId); @@ -343,13 +377,18 @@ public IPrestoSparkTaskExecutor doCreate( ImmutableMap.Builder> shuffleInputs = ImmutableMap.builder(); ImmutableMap.Builder>> pageInputs = ImmutableMap.builder(); + ImmutableMap.Builder>> broadcastInputs = ImmutableMap.builder(); for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) { List remoteSourceRowInputs = new ArrayList<>(); List> remoteSourcePageInputs = new ArrayList<>(); + List> broadcastInputsList = new ArrayList<>(); for (PlanFragmentId sourceFragmentId : remoteSource.getSourceFragmentIds()) { - Iterator> shuffleInput = inputs.getShuffleInputs().get(sourceFragmentId.toString()); - Broadcast> broadcastInput = inputs.getBroadcastInputs().get(sourceFragmentId.toString()); - List inMemoryInput = inputs.getInMemoryInputs().get(sourceFragmentId.toString()); + Iterator> shuffleInput = + (Iterator>) inputs.getShuffleInputs().get(sourceFragmentId.toString()); + Broadcast> broadcastInput = + (Broadcast>) inputs.getBroadcastInputs().get(sourceFragmentId.toString()); + List inMemoryInput = + (List) inputs.getInMemoryInputs().get(sourceFragmentId.toString()); if (shuffleInput != null) { checkArgument(broadcastInput == null, "single remote source is not expected to accept different kind of inputs"); @@ -364,7 +403,7 @@ public IPrestoSparkTaskExecutor doCreate( // NullifyingIterator removes element from the list upon return // This allows GC to gradually reclaim memory // remoteSourcePageInputs.add(getNullifyingIterator(broadcastInput.value())); - remoteSourcePageInputs.add(broadcastInput.value().iterator()); + broadcastInputsList.add(broadcastInput.value()); continue; } @@ -381,6 +420,9 @@ public 
IPrestoSparkTaskExecutor doCreate( if (!remoteSourcePageInputs.isEmpty()) { pageInputs.put(remoteSource.getId(), remoteSourcePageInputs); } + if (!broadcastInputsList.isEmpty()) { + broadcastInputs.put(remoteSource.getId(), broadcastInputsList); + } } OutputBufferMemoryManager memoryManager = new OutputBufferMemoryManager( @@ -398,12 +440,23 @@ public IPrestoSparkTaskExecutor doCreate( false, OptionalInt.empty())); } + + TempDataOperationContext tempDataOperationContext = new TempDataOperationContext( + session.getSource(), + session.getQueryId().getId(), + session.getClientInfo(), + session.getIdentity()); + TempStorage tempStorage = tempStorageManager.getTempStorage(storageBasedBroadcastJoinStorage); + Output output = configureOutput( outputType, blockEncodingManager, memoryManager, getShuffleOutputTargetAverageRowSize(session), - preDeterminedPartition); + preDeterminedPartition, + tempStorage, + tempDataOperationContext, + getStorageBasedBroadcastJoinWriteBufferSize(session)); PrestoSparkOutputBuffer outputBuffer = output.getOutputBuffer(); LocalExecutionPlan localExecutionPlan = localExecutionPlanner.plan( @@ -417,8 +470,13 @@ public IPrestoSparkTaskExecutor doCreate( blockEncodingManager, shuffleInputs.build(), pageInputs.build(), + broadcastInputs.build(), partitionId, - shuffleStatsCollector), + shuffleStatsCollector, + tempStorage, + tempDataOperationContext, + prestoSparkBroadcastTableCacheManager, + stageId), taskDescriptor.getTableWriteInfo(), true); @@ -448,7 +506,9 @@ public IPrestoSparkTaskExecutor doCreate( shuffleStatsCollector, executionExceptionFactory, output.getOutputBufferType(), - outputBuffer); + outputBuffer, + tempStorage, + tempDataOperationContext); } private static OptionalLong computeAllSplitsSize(List taskSources) @@ -484,7 +544,10 @@ private static Output configureOutput( BlockEncodingManager blockEncodingManager, OutputBufferMemoryManager memoryManager, DataSize targetAverageRowSize, - Optional preDeterminedPartition) + Optional 
preDeterminedPartition, + TempStorage tempStorage, + TempDataOperationContext tempDataOperationContext, + DataSize writeBufferSize) { if (outputType.equals(PrestoSparkMutableRow.class)) { PrestoSparkOutputBuffer outputBuffer = new PrestoSparkOutputBuffer<>(memoryManager); @@ -498,6 +561,12 @@ else if (outputType.equals(PrestoSparkSerializedPage.class)) { OutputSupplier outputSupplier = (OutputSupplier) new PageOutputSupplier(outputBuffer); return new Output<>(OutputBufferType.SPARK_PAGE_OUTPUT_BUFFER, outputBuffer, outputFactory, outputSupplier); } + else if (outputType.equals(PrestoSparkStorageHandle.class)) { + PrestoSparkOutputBuffer outputBuffer = new PrestoSparkOutputBuffer<>(memoryManager); + OutputFactory outputFactory = new PrestoSparkPageOutputFactory(outputBuffer, blockEncodingManager); + OutputSupplier outputSupplier = (OutputSupplier) new DiskPageOutputSupplier(outputBuffer, tempStorage, tempDataOperationContext, writeBufferSize); + return new Output<>(OutputBufferType.SPARK_DISK_PAGE_OUTPUT_BUFFER, outputBuffer, outputFactory, outputSupplier); + } else { throw new IllegalArgumentException("Unexpected output type: " + outputType.getName()); } @@ -516,6 +585,8 @@ private static class PrestoSparkTaskExecutor private final PrestoSparkExecutionExceptionFactory executionExceptionFactory; private final OutputBufferType outputBufferType; private final PrestoSparkOutputBuffer outputBuffer; + private final TempStorage tempStorage; + private final TempDataOperationContext tempDataOperationContext; private final UUID taskInstanceId = randomUUID(); @@ -535,7 +606,9 @@ private PrestoSparkTaskExecutor( CollectionAccumulator shuffleStatsCollector, PrestoSparkExecutionExceptionFactory executionExceptionFactory, OutputBufferType outputBufferType, - PrestoSparkOutputBuffer outputBuffer) + PrestoSparkOutputBuffer outputBuffer, + TempStorage tempStorage, + TempDataOperationContext tempDataOperationContext) { this.taskContext = requireNonNull(taskContext, "taskContext is 
null"); this.taskStateMachine = requireNonNull(taskStateMachine, "taskStateMachine is null"); @@ -546,6 +619,8 @@ private PrestoSparkTaskExecutor( this.executionExceptionFactory = requireNonNull(executionExceptionFactory, "executionExceptionFactory is null"); this.outputBufferType = requireNonNull(outputBufferType, "outputBufferType is null"); this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); + this.tempStorage = requireNonNull(tempStorage, "tempStorage is null"); + this.tempDataOperationContext = requireNonNull(tempDataOperationContext, "tempDataOperationContext is null"); } @Override @@ -629,6 +704,22 @@ private Tuple2 doComputeNext() } Throwable failure = getFirst(failures, null); + // Delete the storage file, if task is not successful + if (outputSupplier instanceof DiskPageOutputSupplier) { + PrestoSparkStorageHandle sparkStorageHandle = (PrestoSparkStorageHandle) output._2; + TempStorageHandle tempStorageHandle = tempStorage.deserialize(sparkStorageHandle.getSerializedStorageHandle()); + try { + tempStorage.remove(tempDataOperationContext, tempStorageHandle); + log.info("Removed broadcast spill file: " + tempStorageHandle.toString()); + } + catch (IOException e) { + // self suppression is not allowed + if (e != failure) { + failure.addSuppressed(e); + } + } + } + propagateIfPossible(failure, Error.class); propagateIfPossible(failure, RuntimeException.class); propagateIfPossible(failure, InterruptedException.class); @@ -823,9 +914,131 @@ public long getTimeSpentWaitingForOutputInMillis() } } + private static class DiskPageOutputSupplier + implements OutputSupplier + { + private static final MutablePartitionId DEFAULT_PARTITION = new MutablePartitionId(); + + private final PrestoSparkOutputBuffer outputBuffer; + private final TempStorage tempStorage; + private final TempDataOperationContext tempDataOperationContext; + private final long writeBufferSizeInBytes; + + private TempDataSink tempDataSink; + private long 
timeSpentWaitingForOutputInMillis; + + private DiskPageOutputSupplier(PrestoSparkOutputBuffer outputBuffer, + TempStorage tempStorage, + TempDataOperationContext tempDataOperationContext, + DataSize writeBufferSize) + { + this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); + this.tempStorage = requireNonNull(tempStorage, "tempStorage is null"); + this.tempDataOperationContext = requireNonNull(tempDataOperationContext, "tempDataOperationContext is null"); + this.writeBufferSizeInBytes = requireNonNull(writeBufferSize, "writeBufferSize is null").toBytes(); + } + + @Override + public Tuple2 getNext() + throws InterruptedException + { + long start = System.currentTimeMillis(); + PrestoSparkBufferedSerializedPage page = outputBuffer.get(); + if (page == null) { + return null; + } + + long compressedBroadcastSizeInBytes = 0; + long uncompressedBroadcastSizeInBytes = 0; + int positionCount = 0; + CRC32 checksum = new CRC32(); + TempStorageHandle tempStorageHandle; + IOException ioException = null; + try { + this.tempDataSink = tempStorage.create(tempDataOperationContext); + List bufferedPages = new ArrayList<>(); + long bufferedBytes = 0; + + while (page != null) { + PageDataOutput pageDataOutput = new PageDataOutput(page.getSerializedPage()); + long writtenSize = pageDataOutput.size(); + + if ((writeBufferSizeInBytes - bufferedBytes) < writtenSize) { + tempDataSink.write(bufferedPages); + bufferedPages.clear(); + bufferedBytes = 0; + } + + bufferedPages.add(pageDataOutput); + bufferedBytes += writtenSize; + compressedBroadcastSizeInBytes += page.getSerializedPage().getSizeInBytes(); + uncompressedBroadcastSizeInBytes += page.getSerializedPage().getUncompressedSizeInBytes(); + positionCount += page.getPositionCount(); + Slice slice = page.getSerializedPage().getSlice(); + checksum.update(slice.byteArray(), slice.byteArrayOffset(), slice.length()); + page = outputBuffer.get(); + } + + if (!bufferedPages.isEmpty()) { + 
tempDataSink.write(bufferedPages); + bufferedPages.clear(); + } + + tempStorageHandle = tempDataSink.commit(); + log.info("Created broadcast spill file: " + tempStorageHandle.toString()); + PrestoSparkStorageHandle prestoSparkStorageHandle = + new PrestoSparkStorageHandle( + tempStorage.serializeHandle(tempStorageHandle), + uncompressedBroadcastSizeInBytes, + compressedBroadcastSizeInBytes, + checksum.getValue(), + positionCount); + long end = System.currentTimeMillis(); + timeSpentWaitingForOutputInMillis += (end - start); + return new Tuple2<>(DEFAULT_PARTITION, prestoSparkStorageHandle); + } + catch (IOException e) { + if (ioException == null) { + ioException = e; + } + try { + tempDataSink.rollback(); + } + catch (IOException exception) { + if (ioException != exception) { + ioException.addSuppressed(exception); + } + } + } + finally { + try { + tempDataSink.close(); + } + catch (IOException e) { + if (ioException == null) { + ioException = e; + } + else if (ioException != e) { + ioException.addSuppressed(e); + } + throw new UncheckedIOException("Unable to dump data to disk: ", ioException); + } + } + + throw new UncheckedIOException("Unable to dump data to disk: ", ioException); + } + + @Override + public long getTimeSpentWaitingForOutputInMillis() + { + return timeSpentWaitingForOutputInMillis; + } + } + private enum OutputBufferType { SPARK_ROW_OUTPUT_BUFFER, SPARK_PAGE_OUTPUT_BUFFER, + SPARK_DISK_PAGE_OUTPUT_BUFFER, } } diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java index a1dcfb19c8ab6..b99e599ac8908 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java @@ -24,7 +24,6 @@ import com.facebook.presto.spark.classloader_interface.MutablePartitionId; import 
com.facebook.presto.spark.classloader_interface.PrestoSparkMutableRow; import com.facebook.presto.spark.classloader_interface.PrestoSparkPartitioner; -import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage; import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleSerializer; import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats; import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider; @@ -140,7 +139,7 @@ public JavaPairRDD crea Session session, PlanFragment fragment, Map> rddInputs, - Map>> broadcastInputs, + Map>> broadcastInputs, PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider, CollectionAccumulator taskInfoCollector, CollectionAccumulator shuffleStatsCollector, @@ -260,7 +259,7 @@ private JavaPairRDD cre CollectionAccumulator shuffleStatsCollector, TableWriteInfo tableWriteInfo, Map> rddInputs, - Map>> broadcastInputs, + Map>> broadcastInputs, Class outputType) { checkInputs(fragment.getRemoteSourceNodes(), rddInputs, broadcastInputs); @@ -545,16 +544,16 @@ private static List findTableScanNodes(PlanNode node) .findAll(); } - private static Map>> toTaskProcessorBroadcastInputs(Map>> broadcastInputs) + private static Map>> toTaskProcessorBroadcastInputs(Map>> broadcastInputs) { return broadcastInputs.entrySet().stream() .collect(toImmutableMap(entry -> entry.getKey().toString(), Map.Entry::getValue)); } - private static void checkInputs( + private static void checkInputs( List remoteSources, Map> rddInputs, - Map>> broadcastInputs) + Map>> broadcastInputs) { Set expectedInputs = remoteSources.stream() .map(RemoteSourceNode::getSourceFragmentIds) diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/util/PrestoSparkUtils.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/util/PrestoSparkUtils.java index 0db743c4d4fce..4e23ac52de971 100644 --- 
a/presto-spark-base/src/main/java/com/facebook/presto/spark/util/PrestoSparkUtils.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/util/PrestoSparkUtils.java @@ -20,8 +20,11 @@ import com.facebook.presto.spi.page.PagesSerde; import com.facebook.presto.spi.page.SerializedPage; import com.github.luben.zstd.Zstd; +import com.google.common.util.concurrent.UncheckedExecutionException; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import org.apache.spark.SparkException; +import org.apache.spark.api.java.JavaFutureAction; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -29,13 +32,18 @@ import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.zip.DeflaterInputStream; import java.util.zip.InflaterOutputStream; import static com.facebook.presto.common.block.BlockUtil.compactArray; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Throwables.propagateIfPossible; import static com.google.common.io.ByteStreams.toByteArray; import static java.lang.Math.toIntExact; +import static java.util.concurrent.TimeUnit.MILLISECONDS; public class PrestoSparkUtils { @@ -160,4 +168,53 @@ public void decompress(ByteBuffer input, ByteBuffer output) } }; } + + public static long computeNextTimeout(long queryCompletionDeadline) + throws TimeoutException + { + long timeout = queryCompletionDeadline - System.currentTimeMillis(); + if (timeout <= 0) { + throw new TimeoutException(); + } + return timeout; + } + + public static T getActionResultWithTimeout(JavaFutureAction action, long timeout, TimeUnit timeUnit) + throws SparkException, TimeoutException + { + long deadline = System.currentTimeMillis() + timeUnit.toMillis(timeout); + try { + while (true) { + long nextTimeoutInMillis = deadline - 
System.currentTimeMillis(); + if (nextTimeoutInMillis <= 0) { + throw new TimeoutException(); + } + try { + return action.get(nextTimeoutInMillis, MILLISECONDS); + } + catch (TimeoutException e) { + // guard against spurious wakeup + if (deadline - System.currentTimeMillis() <= 0) { + throw e; + } + } + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + catch (ExecutionException e) { + propagateIfPossible(e.getCause(), SparkException.class); + propagateIfPossible(e.getCause(), RuntimeException.class); + + // this should never happen + throw new UncheckedExecutionException(e.getCause()); + } + finally { + if (!action.isDone()) { + action.cancel(true); + } + } + } } diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkConfig.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkConfig.java index 8c9d95b8384ad..593fa76785744 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkConfig.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkConfig.java @@ -24,6 +24,7 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static io.airlift.units.DataSize.Unit.GIGABYTE; import static io.airlift.units.DataSize.Unit.KILOBYTE; +import static io.airlift.units.DataSize.Unit.MEGABYTE; public class TestPrestoSparkConfig { @@ -36,7 +37,10 @@ public void testDefaults() .setMinSparkInputPartitionCountForAutoTune(100) .setMaxSparkInputPartitionCountForAutoTune(1000) .setMaxSplitsDataSizePerSparkPartition(new DataSize(2, GIGABYTE)) - .setShuffleOutputTargetAverageRowSize(new DataSize(1, KILOBYTE))); + .setShuffleOutputTargetAverageRowSize(new DataSize(1, KILOBYTE)) + .setStorageBasedBroadcastJoinEnabled(false) + .setStorageBasedBroadcastJoinStorage("local") + .setStorageBasedBroadcastJoinWriteBufferSize(new DataSize(24, MEGABYTE))); } @Test @@ 
-49,6 +53,9 @@ public void testExplicitPropertyMappings() .put("spark.max-spark-input-partition-count-for-auto-tune", "2000") .put("spark.max-splits-data-size-per-partition", "4GB") .put("spark.shuffle-output-target-average-row-size", "10kB") + .put("spark.storage-based-broadcast-join-enabled", "true") + .put("spark.storage-based-broadcast-join-storage", "tempfs") + .put("spark.storage-based-broadcast-join-write-buffer-size", "4MB") .build(); PrestoSparkConfig expected = new PrestoSparkConfig() .setSparkPartitionCountAutoTuneEnabled(false) @@ -56,7 +63,10 @@ public void testExplicitPropertyMappings() .setMinSparkInputPartitionCountForAutoTune(200) .setMaxSparkInputPartitionCountForAutoTune(2000) .setMaxSplitsDataSizePerSparkPartition(new DataSize(4, GIGABYTE)) - .setShuffleOutputTargetAverageRowSize(new DataSize(10, KILOBYTE)); + .setShuffleOutputTargetAverageRowSize(new DataSize(10, KILOBYTE)) + .setStorageBasedBroadcastJoinEnabled(true) + .setStorageBasedBroadcastJoinStorage("tempfs") + .setStorageBasedBroadcastJoinWriteBufferSize(new DataSize(4, MEGABYTE)); assertFullMapping(properties, expected); } } diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkQueryRunner.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkQueryRunner.java index 9dfd527209233..09e78b71eb05c 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkQueryRunner.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkQueryRunner.java @@ -18,6 +18,8 @@ import com.facebook.presto.tests.AbstractTestQueryFramework; import org.testng.annotations.Test; +import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; +import static com.facebook.presto.spark.PrestoSparkSessionProperties.STORAGE_BASED_BROADCAST_JOIN_ENABLED; import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static org.assertj.core.api.Assertions.assertThat; @@ -586,6 
+588,28 @@ public void testTimeouts() assertQueryFails(queryMaxExecutionTimeLimitSession, longRunningCrossJoin, "Query exceeded maximum time limit of 2.00s"); } + @Test + public void testDiskBasedBroadcastJoin() + { + Session session = Session.builder(getSession()) + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "BROADCAST") + .setSystemProperty(STORAGE_BASED_BROADCAST_JOIN_ENABLED, "true") + .build(); + + assertQuery(session, + "select * from lineitem l join orders o on l.orderkey = o.orderkey"); + + assertQuery(session, + "select l.orderkey from lineitem l join orders o on l.orderkey = o.orderkey " + + "Union all " + + "SELECT m.nationkey FROM nation m JOIN nation n ON m.nationkey = n.nationkey"); + + assertQuery(session, + "SELECT o.custkey, l.orderkey " + + "FROM (SELECT * FROM lineitem WHERE linenumber = 4) l " + + "CROSS JOIN (SELECT * FROM orders WHERE orderkey = 5) o"); + } + private void assertBucketedQuery(String sql) { assertQuery(sql, sql.replaceAll("_bucketed", "")); diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkConfInitializer.java b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkConfInitializer.java index d6bdf5649d22e..f7a9828cfafd8 100644 --- a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkConfInitializer.java +++ b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkConfInitializer.java @@ -39,7 +39,8 @@ private static void registerKryoClasses(SparkConf sparkConf) PrestoSparkSerializedPage.class, SerializedPrestoSparkTaskDescriptor.class, SerializedTaskInfo.class, - PrestoSparkShuffleStats.class + PrestoSparkShuffleStats.class, + PrestoSparkStorageHandle.class }); } diff --git 
a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkStorageHandle.java b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkStorageHandle.java new file mode 100644 index 0000000000000..090a374c239dc --- /dev/null +++ b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkStorageHandle.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spark.classloader_interface; + +import java.io.Serializable; + +import static java.util.Objects.requireNonNull; + +public class PrestoSparkStorageHandle + implements Serializable, PrestoSparkTaskOutput +{ + private final byte[] serializedStorageHandle; + private final long uncompressedSizeInBytes; + private final long compressedSizeInBytes; + private final long checksum; + private final int positionCount; + + public PrestoSparkStorageHandle( + byte[] serializedStorageHandle, + long uncompressedSizeInBytes, + long compressedSizeInBytes, + long checksum, + int positionCount) + { + this.serializedStorageHandle = requireNonNull(serializedStorageHandle, "serializedStorageHandle is null"); + this.uncompressedSizeInBytes = uncompressedSizeInBytes; + this.compressedSizeInBytes = compressedSizeInBytes; + this.checksum = requireNonNull(checksum, "checksum is null"); + this.positionCount = positionCount; + } + + public long getUncompressedSizeInBytes() + { + return uncompressedSizeInBytes; + } + + public long getCompressedSizeInBytes() + { + return compressedSizeInBytes; + } + + public byte[] getSerializedStorageHandle() + { + return serializedStorageHandle; + } + + public long getChecksum() + { + return checksum; + } + + @Override + public long getPositionCount() + { + return positionCount; + } + + @Override + public long getSize() + { + return uncompressedSizeInBytes; + } +} diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskInputs.java b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskInputs.java index 11827d8291b44..ba75659ec3784 100644 --- a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskInputs.java +++ b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskInputs.java @@ -24,17 
+24,17 @@ import static java.util.Collections.unmodifiableMap; import static java.util.Objects.requireNonNull; -public class PrestoSparkTaskInputs +public class PrestoSparkTaskInputs { // fragmentId -> Iterator<[partitionId, page]> private final Map>> shuffleInputs; - private final Map>> broadcastInputs; + private final Map>> broadcastInputs; // For the COORDINATOR_ONLY fragment we first collect the inputs on the Driver private final Map> inMemoryInputs; public PrestoSparkTaskInputs( Map>> shuffleInputs, - Map>> broadcastInputs, + Map>> broadcastInputs, Map> inMemoryInputs) { this.shuffleInputs = unmodifiableMap(new HashMap<>(requireNonNull(shuffleInputs, "shuffleInputs is null"))); @@ -47,7 +47,7 @@ public Map>> return shuffleInputs; } - public Map>> getBroadcastInputs() + public Map>> getBroadcastInputs() { return broadcastInputs; } diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskProcessor.java b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskProcessor.java index 0d1dd7f6f2520..598cb3adfaf2a 100644 --- a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskProcessor.java +++ b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskProcessor.java @@ -35,7 +35,7 @@ public class PrestoSparkTaskProcessor private final CollectionAccumulator taskInfoCollector; private final CollectionAccumulator shuffleStatsCollector; // fragmentId -> Broadcast - private final Map>> broadcastInputs; + private final Map>> broadcastInputs; private final Class outputType; public PrestoSparkTaskProcessor( @@ -43,7 +43,7 @@ public PrestoSparkTaskProcessor( SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor, CollectionAccumulator taskInfoCollector, CollectionAccumulator shuffleStatsCollector, - Map>> broadcastInputs, + Map>> 
broadcastInputs, Class outputType) { this.taskExecutorFactoryProvider = requireNonNull(taskExecutorFactoryProvider, "taskExecutorFactoryProvider is null"); diff --git a/presto-spark-common/src/main/java/com/facebook/presto/spark/SparkErrorCode.java b/presto-spark-common/src/main/java/com/facebook/presto/spark/SparkErrorCode.java index c1c9878d7fc50..48e0584620712 100644 --- a/presto-spark-common/src/main/java/com/facebook/presto/spark/SparkErrorCode.java +++ b/presto-spark-common/src/main/java/com/facebook/presto/spark/SparkErrorCode.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.ErrorCodeSupplier; import com.facebook.presto.spi.ErrorType; +import static com.facebook.presto.spi.ErrorType.EXTERNAL; import static com.facebook.presto.spi.ErrorType.INSUFFICIENT_RESOURCES; import static com.facebook.presto.spi.ErrorType.INTERNAL_ERROR; @@ -26,7 +27,9 @@ public enum SparkErrorCode GENERIC_SPARK_ERROR(0, INTERNAL_ERROR), SPARK_EXECUTOR_OOM(1, INTERNAL_ERROR), SPARK_EXECUTOR_LOST(2, INTERNAL_ERROR), - EXCEEDED_SPARK_DRIVER_MAX_RESULT_SIZE(3, INSUFFICIENT_RESOURCES) + EXCEEDED_SPARK_DRIVER_MAX_RESULT_SIZE(3, INSUFFICIENT_RESOURCES), + UNSUPPORTED_STORAGE_TYPE(4, INTERNAL_ERROR), + STORAGE_ERROR(5, EXTERNAL) /**/; private final ErrorCode errorCode; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/storage/StorageCapabilities.java b/presto-spi/src/main/java/com/facebook/presto/spi/storage/StorageCapabilities.java new file mode 100644 index 0000000000000..28c22291c039b --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/storage/StorageCapabilities.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.storage; + +public enum StorageCapabilities +{ + REMOTELY_ACCESSIBLE, +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/storage/TempStorage.java b/presto-spi/src/main/java/com/facebook/presto/spi/storage/TempStorage.java index eea8f1c35f092..43d96aa7a1f0e 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/storage/TempStorage.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/storage/TempStorage.java @@ -15,6 +15,7 @@ import java.io.IOException; import java.io.InputStream; +import java.util.List; public interface TempStorage { @@ -26,4 +27,10 @@ InputStream open(TempDataOperationContext context, TempStorageHandle handle) void remove(TempDataOperationContext context, TempStorageHandle handle) throws IOException; + + byte[] serializeHandle(TempStorageHandle storageHandle); + + TempStorageHandle deserialize(byte[] serializedStorageHandle); + + List getStorageCapabilities(); }