diff --git a/pom.xml b/pom.xml
index 468211c5995f0..f0faacf9d5a56 100644
--- a/pom.xml
+++ b/pom.xml
@@ -91,6 +91,8 @@
6.0.0
17.0.0
+ 2
+
+ org.codehaus.plexus:plexus-utils
+ com.google.guava:guava
+ com.fasterxml.jackson.core:jackson-annotations
+ com.fasterxml.jackson.core:jackson-core
+ com.fasterxml.jackson.core:jackson-databind
+
+
+
+
+
+
+
+ org.basepom.maven
+ duplicate-finder-maven-plugin
+
+ true
+
+ com.github.benmanes.caffeine.*
+
+ META-INF.versions.9.module-info
+
+ META-INF.versions.11.module-info
+
+ META-INF.versions.9.org.apache.lucene.*
+
+
+
+
+
+
+
+
+
+
diff --git a/presto-spark-base/pom.xml b/presto-spark-base/pom.xml
index 0501a9f16107f..8f3dd0f91a66d 100644
--- a/presto-spark-base/pom.xml
+++ b/presto-spark-base/pom.xml
@@ -16,6 +16,7 @@
9.4.55.v20240627
4.12.0
3.9.1
+ 2
@@ -51,7 +52,12 @@
com.facebook.presto
presto-spark-classloader-interface
- ${project.version}
+ provided
+
+
+
+ com.facebook.presto
+ presto-spark-classloader-spark${dep.pos.classloader.module-name.suffix}
provided
@@ -459,6 +465,10 @@
javax.ws.rs:javax.ws.rs-api
javax.servlet:javax.servlet-api
+ com.google.code.findbugs:jsr305
+ javax.inject:javax.inject
+ javax.annotation:javax.annotation-api
+ org.apache.httpcomponents:httpcore
@@ -524,5 +534,106 @@
+
+ spark3
+
+
+
+ spark-version
+ 3
+
+
+
+
+ 3
+
+
+
+
+ com.facebook.presto.spark
+ spark-core
+ 3.4.1-1
+ provided
+
+
+
+ com.google.inject
+ guice
+ provided
+
+
+
+ jakarta.annotation
+ jakarta.annotation-api
+ provided
+
+
+
+ com.google.errorprone
+ error_prone_annotations
+ provided
+
+
+
+ org.weakref
+ jmxutils
+ provided
+
+
+
+ jakarta.validation
+ jakarta.validation-api
+ provided
+
+
+
+ org.scala-lang
+ scala-library
+ 2.13.8
+ provided
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-dependency-plugin
+
+
+ org.glassfish.hk2.external:jakarta.inject
+ org.apache.hadoop:hadoop-client-api
+ jakarta.annotation:jakarta.annotation-api
+ jakarta.validation:jakarta.validation-api
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+
+
+ **/TestPrestoSparkExecutionExceptionFactory.java
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+
+
+
+ **/TestPrestoSparkExecutionExceptionFactory.java
+
+
+
+
+
+
+
diff --git a/presto-spark-classloader-interface/pom.xml b/presto-spark-classloader-interface/pom.xml
index cc42a97c56ba6..8f93be12818a4 100644
--- a/presto-spark-classloader-interface/pom.xml
+++ b/presto-spark-classloader-interface/pom.xml
@@ -13,6 +13,7 @@
${project.parent.basedir}
true
+ 2
@@ -21,14 +22,59 @@
spark-core
provided
+
+
+ com.facebook.presto
+ presto-spark-classloader-spark${dep.pos.classloader.module-name.suffix}
+
+
com.google.guava
guava
+
com.facebook.airlift
units
+
+
+ spark3
+
+
+
+ spark-version
+ 3
+
+
+
+
+ 3
+
+
+
+
+
+ com.facebook.presto.spark
+ spark-core
+ 3.4.1-1
+ compile
+
+
+
+
+
+
+ org.scala-lang
+ scala-library
+ 2.13.8
+ provided
+
+
+
+
+
+
diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeTaskRdd.java b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeTaskRdd.java
index a524767a88ead..a7b6abd69eb76 100644
--- a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeTaskRdd.java
+++ b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeTaskRdd.java
@@ -29,6 +29,7 @@
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManagerId;
import scala.Tuple2;
+import scala.Tuple3;
import scala.collection.Iterator;
import scala.collection.Seq;
@@ -38,12 +39,13 @@
import java.util.Optional;
import java.util.stream.Collectors;
+import static com.facebook.presto.spark.classloader_interface.PrestoSparkUtils.asJavaCollection;
+import static com.facebook.presto.spark.classloader_interface.PrestoSparkUtils.getMapSizesByExecutorId;
+import static com.facebook.presto.spark.classloader_interface.PrestoSparkUtils.seqAsJavaList;
import static com.facebook.presto.spark.classloader_interface.ScalaUtils.emptyScalaIterator;
import static com.google.common.base.Preconditions.checkState;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
-import static scala.collection.JavaConversions.asJavaCollection;
-import static scala.collection.JavaConversions.seqAsJavaList;
/**
* PrestoSparkTaskRdd represents execution of Presto stage, it contains:
@@ -173,34 +175,34 @@ private Optional getShuffleWriteDescriptor(in
private List getBlockIds(ShuffledRDDPartition partition, ShuffleHandle shuffleHandle)
{
MapOutputTracker mapOutputTracker = SparkEnv.get().mapOutputTracker();
- Collection>>> mapSizes = asJavaCollection(mapOutputTracker.getMapSizesByExecutorId(
- shuffleHandle.shuffleId(), partition.idx(), partition.idx() + 1));
+ Collection>>> mapSizes = getMapSizesByExecutorId(mapOutputTracker,
+ shuffleHandle.shuffleId(), partition.idx(), partition.idx() + 1);
return mapSizes.stream().map(item -> item._1.executorId()).collect(Collectors.toList());
}
private List getPartitionIds(ShuffledRDDPartition partition, ShuffleHandle shuffleHandle)
{
MapOutputTracker mapOutputTracker = SparkEnv.get().mapOutputTracker();
- Collection>>> mapSizes = asJavaCollection(mapOutputTracker.getMapSizesByExecutorId(
- shuffleHandle.shuffleId(), partition.idx(), partition.idx() + 1));
+ Collection>>> mapSizes = getMapSizesByExecutorId(mapOutputTracker,
+ shuffleHandle.shuffleId(), partition.idx(), partition.idx() + 1);
return mapSizes.stream()
.map(item -> asJavaCollection(item._2))
.flatMap(Collection::stream)
- .map(i -> i._1.toString())
+ .map(i -> i._1().toString())
.collect(Collectors.toList());
}
private List getPartitionSize(ShuffledRDDPartition partition, ShuffleHandle shuffleHandle)
{
MapOutputTracker mapOutputTracker = SparkEnv.get().mapOutputTracker();
- Collection>>> mapSizes = asJavaCollection(mapOutputTracker.getMapSizesByExecutorId(
- shuffleHandle.shuffleId(), partition.idx(), partition.idx() + 1));
+ Collection>>> mapSizes = getMapSizesByExecutorId(mapOutputTracker,
+ shuffleHandle.shuffleId(), partition.idx(), partition.idx() + 1);
//Each partition/BlockManagerId can contain multiple blocks (with BlockId), here sums up all the blocks from each BlockManagerId/Partition
return mapSizes.stream()
.map(
item -> seqAsJavaList(item._2)
.stream()
- .map(item2 -> ((Long) item2._2))
+ .map(item2 -> ((Long) item2._2()))
.reduce(0L, Long::sum))
.collect(Collectors.toList());
}
diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkShuffleSerializer.java b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkShuffleSerializer.java
index d410111018ea7..a906922b415b2 100644
--- a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkShuffleSerializer.java
+++ b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkShuffleSerializer.java
@@ -81,7 +81,11 @@ public ByteBuffer serialize(T input, ClassTag classTag)
+ /**
+ * Points {@code row} at the readable region of {@code buffer} and returns {@code tuple}
+ * ({@code row} and {@code tuple} are fields declared outside this hunk).
+ *
+ * @param buffer array-backed buffer holding the serialized bytes; {@link ByteBuffer#array()} throws
+ * for buffers without an accessible backing array
+ * @param classTag not used by this method
+ * @return {@code tuple}, cast (unchecked) to the requested type
+ */
@Override
public T deserialize(ByteBuffer buffer, ClassTag classTag)
{
- throw new UnsupportedOperationException("this method is never used by shuffle");
+ row.setArray(buffer.array());
+ // Readable bytes start at arrayOffset() + position() and span remaining() bytes; using only
+ // arrayOffset() and the backing array length would ignore the buffer's position and limit.
+ row.setOffset(buffer.arrayOffset() + buffer.position());
+ row.setLength(buffer.remaining());
+ row.setBuffer(buffer);
+ return (T) tuple;
}
public T deserialize(byte[] array, int offset, int length, ClassTag classTag)
@@ -93,9 +97,13 @@ public T deserialize(byte[] array, int offset, int length, ClassTag class
}
+ /**
+ * Points {@code row} at the readable region of {@code buffer} and returns {@code tuple}
+ * ({@code row} and {@code tuple} are fields declared outside this hunk).
+ *
+ * @param buffer array-backed buffer holding the serialized bytes; {@link ByteBuffer#array()} throws
+ * for buffers without an accessible backing array
+ * @param loader not used by this method
+ * @param classTag not used by this method
+ * @return {@code tuple}, cast (unchecked) to the requested type
+ */
@Override
- public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag classTag)
+ public T deserialize(ByteBuffer buffer, ClassLoader loader, ClassTag classTag)
{
- throw new UnsupportedOperationException("this method is never used by shuffle");
+ row.setArray(buffer.array());
+ // Readable bytes start at arrayOffset() + position() and span remaining() bytes; using only
+ // arrayOffset() and the backing array length would ignore the buffer's position and limit.
+ row.setOffset(buffer.arrayOffset() + buffer.position());
+ row.setLength(buffer.remaining());
+ row.setBuffer(buffer);
+ return (T) tuple;
}
}
diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskRdd.java b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskRdd.java
index 9df7c62b2b366..dda535f1b2ced 100644
--- a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskRdd.java
+++ b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskRdd.java
@@ -21,7 +21,7 @@
import org.apache.spark.rdd.ZippedPartitionsPartition;
import scala.Tuple2;
import scala.collection.Iterator;
-import scala.collection.Seq;
+import scala.collection.immutable.Seq;
import scala.reflect.ClassTag;
import java.util.ArrayList;
@@ -30,12 +30,13 @@
import java.util.Map;
import java.util.Optional;
+import static com.facebook.presto.spark.classloader_interface.PrestoSparkUtils.asScalaBuffer;
+import static com.facebook.presto.spark.classloader_interface.PrestoSparkUtils.seqAsJavaList;
+import static com.facebook.presto.spark.classloader_interface.PrestoSparkUtils.toImmutableSeq;
import static com.facebook.presto.spark.classloader_interface.ScalaUtils.emptyScalaIterator;
import static java.lang.String.format;
import static java.util.Collections.unmodifiableMap;
import static java.util.Objects.requireNonNull;
-import static scala.collection.JavaConversions.asScalaBuffer;
-import static scala.collection.JavaConversions.seqAsJavaList;
/**
* PrestoSparkTaskRdd represents execution of Presto stage, it contains:
@@ -101,7 +102,7 @@ private static Seq> getRDDSequence(Optional tas
{
List> list = new ArrayList<>(shuffleInputRdds);
taskSourceRdd.ifPresent(list::add);
- return asScalaBuffer(list).toSeq();
+ return toImmutableSeq(asScalaBuffer(list).toSeq());
}
private static ClassTag fakeClassTag()
diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskSourceRdd.java b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskSourceRdd.java
index 97636d1168705..3a17007c9b97c 100644
--- a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskSourceRdd.java
+++ b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkTaskSourceRdd.java
@@ -27,9 +27,9 @@
import java.util.Collections;
import java.util.List;
+import static com.facebook.presto.spark.classloader_interface.PrestoSparkUtils.asScalaBuffer;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
-import static scala.collection.JavaConversions.asScalaBuffer;
public class PrestoSparkTaskSourceRdd
extends RDD
diff --git a/presto-spark-classloader-spark2/pom.xml b/presto-spark-classloader-spark2/pom.xml
new file mode 100644
index 0000000000000..efbe208415200
--- /dev/null
+++ b/presto-spark-classloader-spark2/pom.xml
@@ -0,0 +1,32 @@
+
+
+
+ presto-root
+ com.facebook.presto
+ 0.295-SNAPSHOT
+
+ 4.0.0
+
+ presto-spark-classloader-spark2
+ presto-spark-classloader-spark2
+
+
+ ${project.parent.basedir}
+
+
+
+
+
+ com.facebook.presto.spark
+ spark-core
+ provided
+
+
+
+ com.google.guava
+ guava
+
+
+
+
+
\ No newline at end of file
diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeExecutionShuffleManager.java b/presto-spark-classloader-spark2/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeExecutionShuffleManager.java
similarity index 100%
rename from presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeExecutionShuffleManager.java
rename to presto-spark-classloader-spark2/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeExecutionShuffleManager.java
diff --git a/presto-spark-classloader-spark2/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkUtils.java b/presto-spark-classloader-spark2/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkUtils.java
new file mode 100644
index 0000000000000..0fedbb986cb07
--- /dev/null
+++ b/presto-spark-classloader-spark2/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkUtils.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.classloader_interface;
+
+import org.apache.spark.MapOutputTracker;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManagerId;
+import scala.Tuple2;
+import scala.Tuple3;
+import scala.collection.Iterable;
+import scala.collection.JavaConversions;
+import scala.collection.Seq;
+import scala.collection.mutable.Buffer;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Static helpers used by the Spark 2 build of the classloader module: thin wrappers over
+ * {@code scala.collection.JavaConversions} plus an adapter around
+ * {@code MapOutputTracker.getMapSizesByExecutorId}.
+ */
+public class PrestoSparkUtils
+{
+ // Not instantiable: the class only exposes static methods
+ private PrestoSparkUtils()
+ {
+ }
+
+ /**
+ * Delegates to {@code JavaConversions.asJavaCollection}.
+ *
+ * @param iterable the Scala iterable to convert
+ * @return the Java collection produced by the Scala conversion
+ */
+ public static Collection asJavaCollection(Iterable iterable)
+ {
+ return JavaConversions.asJavaCollection(iterable);
+ }
+
+ /**
+ * Delegates to {@code JavaConversions.seqAsJavaList}.
+ *
+ * @param iterable the Scala sequence to convert
+ * @return the Java list produced by the Scala conversion
+ */
+ public static List seqAsJavaList(Seq iterable)
+ {
+ return JavaConversions.seqAsJavaList(iterable);
+ }
+
+ /**
+ * Delegates to {@code JavaConversions.asScalaBuffer}.
+ *
+ * @param list the Java list to convert
+ * @return the Scala buffer produced by the Scala conversion
+ */
+ public static Buffer asScalaBuffer(List list)
+ {
+ return JavaConversions.asScalaBuffer(list);
+ }
+
+ /**
+ * Calls the tracker's three-argument {@code getMapSizesByExecutorId}, materializes the result with
+ * {@code toList()} and passes it through {@link #convertCollection}, so every inner tuple is returned
+ * as a {@code Tuple3} whose third element is {@code null}.
+ *
+ * @param mapOutputTracker the tracker to query
+ * @param shuffleId the shuffle to look up
+ * @param startPartitionId first argument of the partition range forwarded to the tracker
+ * @param endPartitionId second argument of the partition range forwarded to the tracker
+ * @return the converted entries as a Java collection
+ */
+ public static Collection>>> getMapSizesByExecutorId(
+ MapOutputTracker mapOutputTracker, int shuffleId, int startPartitionId, int endPartitionId)
+ {
+ return convertCollection(asJavaCollection(mapOutputTracker.getMapSizesByExecutorId(
+ shuffleId, startPartitionId, endPartitionId).toList()));
+ }
+
+ /**
+ * Converts a Scala Seq to a Scala immutable Seq.
+ * A sequence that is already a {@code scala.collection.immutable.Seq} is returned as-is;
+ * any other sequence is copied through a Java list into a Scala {@code List}.
+ *
+ * @param seq the Scala sequence to convert
+ * @param <T> the type of elements in the sequence
+ * @return an immutable Scala sequence containing the same elements
+ */
+ public static scala.collection.immutable.Seq toImmutableSeq(Seq seq)
+ {
+ if (seq instanceof scala.collection.immutable.Seq) {
+ return (scala.collection.immutable.Seq) seq;
+ }
+ else {
+ List javaList = seqAsJavaList(seq);
+ return asScalaBuffer(javaList).toList();
+ }
+ }
+
+ /**
+ * Converts every entry of the input collection: the entry's first component is kept unchanged and
+ * its Seq of {@code Tuple2} is replaced by a Seq of {@code Tuple3} whose third element is {@code null}.
+ *
+ * @param inputCollection the original entries, each pairing a key with a Seq of {@code Tuple2} elements
+ * @return a new list of entries, each pairing the same key with a Seq of {@code Tuple3} elements
+ */
+ public static Collection>>> convertCollection(
+ Collection>>> inputCollection)
+ {
+ return inputCollection.stream()
+ .map(entry -> new Tuple2<>(entry._1(), convertTuple2SeqToTuple3Seq(entry._2(), null)))
+ .collect(Collectors.toList());
+ }
+
+ /**
+ * Utility method to convert a Seq of Tuple2 to a Seq of Tuple3 by adding a third element.
+ * The same {@code thirdElement} reference is used for every tuple.
+ *
+ * @param tuple2Seq The original Seq of Tuple2 elements
+ * @param thirdElement The third element to add to each Tuple2
+ * @return A Seq of Tuple3 elements with the third element added
+ */
+ public static Seq> convertTuple2SeqToTuple3Seq(Seq> tuple2Seq, Object thirdElement)
+ {
+ List> tuple3List = seqAsJavaList(tuple2Seq).stream()
+ .map(tuple2 -> new Tuple3<>(tuple2._1, tuple2._2, thirdElement))
+ .collect(Collectors.toList());
+ return asScalaBuffer(tuple3List).toSeq();
+ }
+}
diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/ScalaUtils.java b/presto-spark-classloader-spark2/src/main/java/com/facebook/presto/spark/classloader_interface/ScalaUtils.java
similarity index 100%
rename from presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/ScalaUtils.java
rename to presto-spark-classloader-spark2/src/main/java/com/facebook/presto/spark/classloader_interface/ScalaUtils.java
diff --git a/presto-spark-classloader-spark3/pom.xml b/presto-spark-classloader-spark3/pom.xml
new file mode 100644
index 0000000000000..a9bcd8232da78
--- /dev/null
+++ b/presto-spark-classloader-spark3/pom.xml
@@ -0,0 +1,40 @@
+
+
+
+ presto-root
+ com.facebook.presto
+ 0.295-SNAPSHOT
+
+ 4.0.0
+
+ presto-spark-classloader-spark3
+ presto-spark-classloader-spark3
+
+
+ ${project.parent.basedir}
+
+
+
+
+
+ com.google.guava
+ guava
+
+
+
+ com.facebook.presto.spark
+ spark-core
+ 3.4.1-1
+ compile
+
+
+
+ org.scala-lang
+ scala-library
+ 2.13.8
+ provided
+
+
+
+
+
\ No newline at end of file
diff --git a/presto-spark-classloader-spark3/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeExecutionShuffleManager.java b/presto-spark-classloader-spark3/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeExecutionShuffleManager.java
new file mode 100644
index 0000000000000..ba994cf5c3200
--- /dev/null
+++ b/presto-spark-classloader-spark3/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeExecutionShuffleManager.java
@@ -0,0 +1,258 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.classloader_interface;
+
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.shuffle.BaseShuffleHandle;
+import org.apache.spark.shuffle.ShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleHandle;
+import org.apache.spark.shuffle.ShuffleManager;
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
+import org.apache.spark.shuffle.ShuffleReader;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.sort.BypassMergeSortShuffleHandle;
+import org.apache.spark.storage.BlockManager;
+import scala.Option;
+import scala.Product2;
+import scala.collection.Iterator;
+
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+
+import static com.facebook.presto.spark.classloader_interface.ScalaUtils.emptyScalaIterator;
+import static com.google.common.base.Preconditions.checkState;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * {@link PrestoSparkNativeExecutionShuffleManager} is the shuffle manager implementing the Spark shuffle manager interface specifically for native execution. The reasons we have this
+ * new shuffle manager are:
+ * 1. To bypass calling into the Spark Java shuffle writer/reader, since the actual shuffle read/write happens on the C++ side. In PrestoSparkNativeExecutionShuffleManager, we register
+ * a pair of no-op shuffle reader/writer to hook up with the regular Spark shuffle workflow.
+ * 2. To capture the shuffle metadata (e.g. {@link ShuffleHandle}) for later use. This metadata is only available internally during shuffle writer creation, which is outside the
+ * Presto-Spark native execution flow. By using the {@link PrestoSparkNativeExecutionShuffleManager}, we capture and store this metadata inside the shuffle manager and provide
+ * the APIs that allow the native execution runtime to access it.
+ */
+public class PrestoSparkNativeExecutionShuffleManager
+ implements ShuffleManager
+{
+ private final Map partitionIdToShuffleHandle = new ConcurrentHashMap<>();
+ private final Map> shuffleIdToBaseShuffleHandle = new ConcurrentHashMap<>();
+ private final ShuffleManager fallbackShuffleManager;
+ private static final String FALLBACK_SPARK_SHUFFLE_MANAGER = "spark.fallback.shuffle.manager";
+
+ public PrestoSparkNativeExecutionShuffleManager(SparkConf conf)
+ {
+ fallbackShuffleManager = instantiateClass(conf.get(FALLBACK_SPARK_SHUFFLE_MANAGER), conf);
+ }
+
+ /**
+ * Reflectively creates an instance of the class named {@code className} through its public
+ * single-argument {@code (SparkConf)} constructor.
+ *
+ * @param className fully qualified name of the class to instantiate
+ * @param conf configuration passed to the constructor
+ * @return the new instance, cast (unchecked) to the type expected by the caller
+ * @throws RuntimeException wrapping the reflective failure when the class cannot be loaded, has no
+ * such constructor, or the constructor cannot be invoked or itself throws
+ */
+ private static <T> T instantiateClass(String className, SparkConf conf)
+ {
+ try {
+ return (T) (Class.forName(className).getConstructor(SparkConf.class).newInstance(conf));
+ }
+ catch (ClassNotFoundException | InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
+ // Only ClassNotFoundException means "class not found"; the message has to cover every reflective failure
+ throw new RuntimeException(format("Failed to instantiate class %s", className), e);
+ }
+ }
+
+ protected void registerShuffleHandle(BaseShuffleHandle handle, int stageId, long mapId)
+ {
+ partitionIdToShuffleHandle.put(new StageAndMapId(stageId, mapId), handle);
+ shuffleIdToBaseShuffleHandle.put(handle.shuffleId(), handle);
+ }
+
+ protected void unregisterShuffleHandle(int shuffleId, int stageId, long mapId)
+ {
+ partitionIdToShuffleHandle.remove(new StageAndMapId(stageId, mapId));
+ shuffleIdToBaseShuffleHandle.remove(shuffleId);
+ }
+
+ @Override
+ public ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency dependency)
+ {
+ return fallbackShuffleManager.registerShuffle(shuffleId, dependency);
+ }
+
+ @Override
+ public ShuffleWriter getWriter(ShuffleHandle handle, long mapId, TaskContext context, ShuffleWriteMetricsReporter metrics)
+ {
+ checkState(
+ requireNonNull(handle, "handle is null") instanceof BypassMergeSortShuffleHandle,
+ "class %s is not instance of BypassMergeSortShuffleHandle", handle.getClass().getName());
+ BaseShuffleHandle, ?, ?> baseShuffleHandle = (BaseShuffleHandle, ?, ?>) handle;
+ int shuffleId = baseShuffleHandle.shuffleId();
+ int stageId = context.stageId();
+ registerShuffleHandle(baseShuffleHandle, stageId, mapId);
+ return new EmptyShuffleWriter<>(
+ baseShuffleHandle.dependency().partitioner().numPartitions(),
+ () -> unregisterShuffleHandle(shuffleId, stageId, mapId));
+ }
+
+ @Override
+ public ShuffleReader getReader(ShuffleHandle handle, int startPartition, int endPartition, TaskContext context, ShuffleReadMetricsReporter metrics)
+ {
+ return new EmptyShuffleReader<>();
+ }
+
+ @Override
+ public ShuffleReader getReader(ShuffleHandle handle, int startMapIndex, int endMapIndex, int startPartition, int endPartition, TaskContext context, ShuffleReadMetricsReporter metrics)
+ {
+ return new EmptyShuffleReader<>();
+ }
+
+ @Override
+ public boolean unregisterShuffle(int shuffleId)
+ {
+ // Delegates to the fallback manager; the handle maps of this class are not touched here
+ // (they are cleared through unregisterShuffleHandle).
+ // NOTE(review): the fallback's boolean result is discarded and true is always returned —
+ // confirm that callers never need to observe a failed unregistration.
+ fallbackShuffleManager.unregisterShuffle(shuffleId);
+ return true;
+ }
+
+ @Override
+ public ShuffleBlockResolver shuffleBlockResolver()
+ {
+ return fallbackShuffleManager.shuffleBlockResolver();
+ }
+
+ @Override
+ public void stop()
+ {
+ fallbackShuffleManager.stop();
+ }
+
+ /*
+ * Returns the handle that getWriter() recorded for the given stage and map id, or Optional.empty() when none is recorded.
+ * This method can only be called inside Rdd's compute method, otherwise partitionIdToShuffleHandle may not contain the corresponding ShuffleHandle object.
+ * The reason is that in Spark's ShuffleMapTask, it's guaranteed to call getWriter(handle, mapId, context, metrics) on the shuffle manager first before calling
+ * the Rdd.compute() method, therefore the ShuffleHandle object will always be added to partitionIdToShuffleHandle in getWriter before Rdd.compute().
+ * The entry is removed again when the writer returned by getWriter is stopped.
+ */
+ public Optional getShuffleHandle(int stageId, int mapId)
+ {
+ // get() already returns null for a missing key, which ofNullable turns into Optional.empty()
+ return Optional.ofNullable(partitionIdToShuffleHandle.get(new StageAndMapId(stageId, mapId)));
+ }
+
+ public boolean hasRegisteredShuffleHandles()
+ {
+ return !partitionIdToShuffleHandle.isEmpty() || !shuffleIdToBaseShuffleHandle.isEmpty();
+ }
+
+ /**
+ * Returns the number of partitions of the partitioner attached to the registered shuffle.
+ *
+ * @param shuffleId id of a shuffle whose handle was recorded by {@code registerShuffleHandle}
+ * @return {@code numPartitions()} of the shuffle dependency's partitioner
+ * @throws RuntimeException if no handle is currently registered for {@code shuffleId}
+ */
+ public int getNumOfPartitions(int shuffleId)
+ {
+ // Single lookup: a containsKey()/get() pair can race with unregisterShuffleHandle() and dereference null
+ BaseShuffleHandle<?, ?, ?> handle = shuffleIdToBaseShuffleHandle.get(shuffleId);
+ if (handle == null) {
+ throw new RuntimeException(format("shuffleId=[%s] is not registered", shuffleId));
+ }
+ return handle.dependency().partitioner().numPartitions();
+ }
+
+ static class EmptyShuffleReader
+ implements ShuffleReader
+ {
+ @Override
+ public Iterator> read()
+ {
+ return emptyScalaIterator();
+ }
+ }
+
+ static class EmptyShuffleWriter
+ extends ShuffleWriter
+ {
+ private final long[] mapStatus;
+ private final Runnable onStop;
+ private static final long DEFAULT_MAP_STATUS = 1L;
+
+ public EmptyShuffleWriter(int totalMapStages, Runnable onStop)
+ {
+ this.mapStatus = new long[totalMapStages];
+ this.onStop = requireNonNull(onStop, "onStop is null");
+ Arrays.fill(mapStatus, DEFAULT_MAP_STATUS);
+ }
+
+ @Override
+ public void write(Iterator> records)
+ throws IOException
+ {
+ if (records.hasNext()) {
+ throw new RuntimeException("EmptyShuffleWriter can only take empty write input.");
+ }
+ }
+
+ @Override
+ public Option stop(boolean success)
+ {
+ onStop.run();
+ BlockManager blockManager = SparkEnv.get().blockManager();
+ return Option.apply(
+ MapStatus$.MODULE$.apply(blockManager.blockManagerId(), mapStatus, 0L));
+ }
+
+ @Override
+ public long[] getPartitionLengths()
+ {
+ return mapStatus;
+ }
+ }
+
+ public static class StageAndMapId
+ {
+ private final int stageId;
+ private final long mapId;
+
+ public StageAndMapId(int stageId, long mapId)
+ {
+ this.stageId = stageId;
+ this.mapId = mapId;
+ }
+
+ public int getStageId()
+ {
+ return stageId;
+ }
+
+ public long getMapId()
+ {
+ return mapId;
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ StageAndMapId that = (StageAndMapId) o;
+ return stageId == that.stageId && mapId == that.mapId;
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(stageId, mapId);
+ }
+ }
+}
diff --git a/presto-spark-classloader-spark3/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkUtils.java b/presto-spark-classloader-spark3/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkUtils.java
new file mode 100644
index 0000000000000..81d0cdba34b40
--- /dev/null
+++ b/presto-spark-classloader-spark3/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkUtils.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.classloader_interface;
+
+import org.apache.spark.MapOutputTracker;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManagerId;
+import scala.Tuple2;
+import scala.Tuple3;
+import scala.collection.Iterable;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+import scala.collection.mutable.Buffer;
+
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Static helpers used by the Spark 3 build of the classloader module: thin wrappers over
+ * {@code scala.collection.JavaConverters} plus an adapter around
+ * {@code MapOutputTracker.getMapSizesByExecutorId}.
+ */
+public class PrestoSparkUtils
+{
+ // Not instantiable: the class only exposes static methods
+ private PrestoSparkUtils()
+ {
+ }
+
+ /**
+ * Delegates to {@code JavaConverters.asJavaCollection}.
+ *
+ * @param iterable the Scala iterable to convert
+ * @return the Java collection produced by the Scala conversion
+ */
+ public static Collection asJavaCollection(Iterable iterable)
+ {
+ return JavaConverters.asJavaCollection(iterable);
+ }
+
+ /**
+ * Delegates to {@code JavaConverters.seqAsJavaList}.
+ *
+ * @param iterable the Scala sequence to convert
+ * @return the Java list produced by the Scala conversion
+ */
+ public static List seqAsJavaList(Seq iterable)
+ {
+ return JavaConverters.seqAsJavaList(iterable);
+ }
+
+ /**
+ * Delegates to {@code JavaConverters.asScalaBuffer}.
+ *
+ * @param list the Java list to convert
+ * @return the Scala buffer produced by the Scala conversion
+ */
+ public static Buffer asScalaBuffer(List list)
+ {
+ return JavaConverters.asScalaBuffer(list);
+ }
+
+ /**
+ * Calls the tracker's five-argument {@code getMapSizesByExecutorId}, materializes the result with
+ * {@code toList()} and exposes it as a Java collection.
+ *
+ * @param mapOutputTracker the tracker to query
+ * @param shuffleId the shuffle to look up
+ * @param startPartitionId fourth argument forwarded to the tracker
+ * @param endPartitionId fifth argument forwarded to the tracker
+ * @return the tracker's entries as a Java collection
+ */
+ public static Collection>>> getMapSizesByExecutorId(
+ MapOutputTracker mapOutputTracker, int shuffleId, int startPartitionId, int endPartitionId)
+ {
+ // NOTE(review): 0 and Integer.MAX_VALUE are presumably the map-index range (i.e. all map outputs) —
+ // confirm against the Spark 3 MapOutputTracker API
+ return asJavaCollection(mapOutputTracker.getMapSizesByExecutorId(
+ shuffleId, 0, Integer.MAX_VALUE, startPartitionId, endPartitionId).toList());
+ }
+
+ /**
+ * Converts a Scala Seq to a Scala immutable Seq.
+ * A sequence that is already a {@code scala.collection.immutable.Seq} is returned as-is;
+ * any other sequence is copied through a Java list into a Scala {@code List}.
+ *
+ * @param seq the Scala sequence to convert
+ * @param <T> the type of elements in the sequence
+ * @return an immutable Scala sequence containing the same elements
+ */
+ public static scala.collection.immutable.Seq toImmutableSeq(Seq seq)
+ {
+ if (seq instanceof scala.collection.immutable.Seq) {
+ return (scala.collection.immutable.Seq) seq;
+ }
+ else {
+ List javaList = seqAsJavaList(seq);
+ return asScalaBuffer(javaList).toList();
+ }
+ }
+}
diff --git a/presto-spark-classloader-spark3/src/main/java/com/facebook/presto/spark/classloader_interface/ScalaUtils.java b/presto-spark-classloader-spark3/src/main/java/com/facebook/presto/spark/classloader_interface/ScalaUtils.java
new file mode 100644
index 0000000000000..88b92901716e7
--- /dev/null
+++ b/presto-spark-classloader-spark3/src/main/java/com/facebook/presto/spark/classloader_interface/ScalaUtils.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.classloader_interface;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+import static java.util.Collections.unmodifiableList;
+
+/**
+ * Static helpers for working with {@code scala.collection.Iterator} from Java.
+ */
+public class ScalaUtils
+{
+ // Not instantiable: the class only exposes static methods
+ private ScalaUtils() {}
+
+ /**
+ * Drains {@code iterator} into a list.
+ *
+ * @param iterator the Scala iterator to consume; it is exhausted when this method returns
+ * @return an unmodifiable list holding the elements in iteration order
+ */
+ public static List collectScalaIterator(scala.collection.Iterator iterator)
+ {
+ List result = new ArrayList<>();
+ while (iterator.hasNext()) {
+ result.add(iterator.next());
+ }
+ return unmodifiableList(result);
+ }
+
+ /**
+ * Creates a Scala iterator with no elements.
+ *
+ * @return a new iterator whose {@code hasNext()} is always {@code false} and whose {@code next()}
+ * throws {@link NoSuchElementException}
+ */
+ public static scala.collection.Iterator emptyScalaIterator()
+ {
+ return new scala.collection.AbstractIterator()
+ {
+ @Override
+ public boolean hasNext()
+ {
+ return false;
+ }
+
+ @Override
+ public T next()
+ {
+ throw new NoSuchElementException();
+ }
+ };
+ }
+}
diff --git a/presto-spark-launcher/pom.xml b/presto-spark-launcher/pom.xml
index be5bae5b73055..0b0dd2bb59d92 100644
--- a/presto-spark-launcher/pom.xml
+++ b/presto-spark-launcher/pom.xml
@@ -48,8 +48,39 @@
spark-core
provided
+
+
+
+ spark3
+
+
+
+ spark-version
+ 3
+
+
+
+
+
+ com.facebook.presto.spark
+ spark-core
+ 3.4.1-1
+ provided
+
+
+
+ org.scala-lang
+ scala-library
+ 2.13.8
+ provided
+
+
+
+
+
+