diff --git a/.github/workflows/prestocpp-linux-build-and-unit-test.yml b/.github/workflows/prestocpp-linux-build-and-unit-test.yml
index 0dcdcd549795f..f6bd5744658ae 100644
--- a/.github/workflows/prestocpp-linux-build-and-unit-test.yml
+++ b/.github/workflows/prestocpp-linux-build-and-unit-test.yml
@@ -242,6 +242,84 @@ jobs:
-Duser.timezone=America/Bahia_Banderas \
-T1C
+ prestocpp-linux-presto-on-spark-e2e-tests:
+ needs: prestocpp-linux-build-for-test
+ runs-on: ubuntu-22.04
+ strategy:
+ fail-fast: false
+ matrix:
+ storage-format: [ "PARQUET", "DWRF" ]
+ enable-sidecar: [ "true", "false" ]
+ container:
+ image: prestodb/presto-native-dependency:0.293-20250522140509-484b00e
+ env:
+ MAVEN_OPTS: "-Xmx4G -XX:+ExitOnOutOfMemoryError"
+ MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all -Dmaven.javadoc.skip=true"
+ MAVEN_TEST: "-B -Dair.check.skip-all -Dmaven.javadoc.skip=true -DLogTestDurationListener.enabled=true --fail-at-end"
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Fix git permissions
+ # Usually actions/checkout does this but as we run in a container
+ # it doesn't work
+ run: git config --global --add safe.directory ${GITHUB_WORKSPACE}
+
+ - name: Download artifacts
+ uses: actions/download-artifact@v4
+ with:
+ name: presto-native-build
+ path: presto-native-execution/_build/release
+
+ # Permissions are lost when uploading. Details here: https://github.com/actions/upload-artifact/issues/38
+ - name: Restore execute permissions and library path
+ run: |
+ chmod +x ${GITHUB_WORKSPACE}/presto-native-execution/_build/release/presto_cpp/main/presto_server
+ chmod +x ${GITHUB_WORKSPACE}/presto-native-execution/_build/release/velox/velox/functions/remote/server/velox_functions_remote_server_main
+ # Ensure transitive dependency libboost-iostreams is found.
+ ldconfig /usr/local/lib
+
+ - name: Install OpenJDK8
+ uses: actions/setup-java@v4
+ with:
+ distribution: 'temurin'
+ java-version: '8.0.442'
+ cache: 'maven'
+ - name: Download nodejs to maven cache
+ run: .github/bin/download_nodejs
+
+ - name: Maven install
+ env:
+ # Use different Maven options to install.
+ MAVEN_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError"
+ run: |
+ for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-tests' -am && s=0 && break || s=$? && sleep 10; done; (exit $s)
+
+ - name: Run presto-on-spark native tests
+ run: |
+ export PRESTO_SERVER_PATH="${GITHUB_WORKSPACE}/presto-native-execution/_build/release/presto_cpp/main/presto_server"
+ export TESTFILES=`find ./presto-native-execution/src/test -type f -name 'TestPrestoSpark*.java'`
+ # Convert file paths to comma separated class names
+ export TESTCLASSES=TestPrestoSparkExpressionCompiler,TestPrestoSparkNativeBitwiseFunctionQueries,TestPrestoSparkNativeTpchConnectorQueries,TestPrestoSparkNativeSimpleQueries,TestPrestoSparkSqlFunctions,TestPrestoSparkNativeTpchQueries
+ for test_file in $TESTFILES
+ do
+ tmp=${test_file##*/}
+ test_class=${tmp%%\.*}
+ export TESTCLASSES="${TESTCLASSES},$test_class"
+ done
+ export TESTCLASSES=${TESTCLASSES#,}
+ echo "TESTCLASSES = $TESTCLASSES"
+
+ mvn test \
+ ${MAVEN_TEST} \
+ -pl 'presto-native-execution' \
+ -DstorageFormat=${{ matrix.storage-format }} \
+ -DsidecarEnabled=${{ matrix.enable-sidecar }} \
+ -Dtest="${TESTCLASSES}" \
+ -DPRESTO_SERVER=${PRESTO_SERVER_PATH} \
+ -DDATA_DIR=${RUNNER_TEMP} \
+ -Duser.timezone=America/Bahia_Banderas \
+ -T1C
+
prestocpp-linux-presto-sidecar-tests:
needs: prestocpp-linux-build-for-test
runs-on: ubuntu-22.04
diff --git a/pom.xml b/pom.xml
index e7301cc38d99a..af503b8e5adde 100644
--- a/pom.xml
+++ b/pom.xml
@@ -2440,9 +2440,9 @@
- com.facebook.presto.spark
- spark-core
- 2.0.2-6
+ org.apache.spark
+ spark-core_2.13
+ 3.4.0
provided
diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java
index 34e9330eab12e..3bcf84c00cadd 100644
--- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java
+++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java
@@ -1560,6 +1560,26 @@ public SystemSessionProperties(
"Enable execution on native engine",
featuresConfig.isNativeExecutionEnabled(),
true),
+ stringProperty(
+ NATIVE_EXECUTION_EXECUTABLE_PATH,
+ "The native engine executable file path for native engine execution",
+ featuresConfig.getNativeExecutionExecutablePath(),
+ false),
+ stringProperty(
+ NATIVE_EXECUTION_PROGRAM_ARGUMENTS,
+ "Program arguments for native engine execution. The main target use case for this " +
+ "property is to control logging levels using glog flags. E,g, to enable verbose mode, add " +
+ "'--v 1'. More advanced glog gflags usage can be found at " +
+ "https://rpg.ifi.uzh.ch/docs/glog.html\n" +
+ "e.g. --vmodule=mapreduce=2,file=1,gfs*=3 --v=0\n" +
+ "will:\n" +
+ "\n" +
+ "a. Print VLOG(2) and lower messages from mapreduce.{h,cc}\n" +
+ "b. Print VLOG(1) and lower messages from file.{h,cc}\n" +
+ "c. Print VLOG(3) and lower messages from files prefixed with \"gfs\"\n" +
+ "d. Print VLOG(0) and lower messages from elsewhere",
+ featuresConfig.getNativeExecutionProgramArguments(),
+ false),
booleanProperty(
NATIVE_EXECUTION_PROCESS_REUSE_ENABLED,
"Enable reuse the native process within the same JVM",
diff --git a/presto-native-execution/pom.xml b/presto-native-execution/pom.xml
index cc69d2dd26e9c..6737b0ebefaf1 100644
--- a/presto-native-execution/pom.xml
+++ b/presto-native-execution/pom.xml
@@ -33,6 +33,11 @@
guava
+
+ com.google.inject
+ guice
+
+
org.testng
testng
@@ -206,8 +211,8 @@
- com.facebook.presto.spark
- spark-core
+ org.apache.spark
+ spark-core_2.13
test
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java
new file mode 100644
index 0000000000000..7313ac9f87414
--- /dev/null
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/PrestoSparkNativeQueryRunnerUtils.java
@@ -0,0 +1,237 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark;
+
+import com.facebook.airlift.log.Logging;
+import com.facebook.presto.hive.metastore.Database;
+import com.facebook.presto.hive.metastore.ExtendedHiveMetastore;
+import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils;
+import com.facebook.presto.spark.execution.nativeprocess.NativeExecutionModule;
+import com.facebook.presto.spark.execution.property.NativeExecutionConnectorConfig;
+import com.facebook.presto.spi.security.PrincipalType;
+import com.facebook.presto.testing.QueryRunner;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.inject.Module;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.facebook.airlift.log.Level.WARN;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.getNativeWorkerHiveProperties;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.getNativeWorkerSystemProperties;
+import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.getNativeQueryRunnerParameters;
+import static com.facebook.presto.spark.PrestoSparkQueryRunner.METASTORE_CONTEXT;
+import static java.lang.String.format;
+import static java.nio.file.Files.createTempDirectory;
+
+/**
+ * Following JVM argument is needed to run Spark native tests.
+ *
+ * - PRESTO_SERVER
+ * - This tells Spark where to find the Presto native binary to launch the process.
+ * Example: -DPRESTO_SERVER=/path/to/native/process/bin
+ *
+ * - DATA_DIR
+ * - Optional path to store TPC-H tables used in the test. If this directory is empty, it will be
+ * populated. If tables already exist, they will be reused.
+ *
+ * Tests can be run in Interactive Debugging Mode that allows for an easier debugging
+ * experience. Instead of launching its own native process, the test will connect to an existing
+ * native process. This gives developers flexibility to connect IDEA and debuggers to the native process.
+ * Enable this mode by setting NATIVE_PORT JVM argument.
+ *
+ * - NATIVE_PORT
+ * - This is the port your externally launched native process listens to. It is used to tell Spark where to send
+ * requests. This port number has to be the same as to which your externally launched process listens.
+ * Example: -DNATIVE_PORT=7777.
+ * When NATIVE_PORT is specified, PRESTO_SERVER argument is not required and is ignored if specified.
+ *
+ * For test queries requiring shuffle, the disk-based local shuffle will be used.
+ */
+public class PrestoSparkNativeQueryRunnerUtils
+{
+ private static final int AVAILABLE_CPU_COUNT = 4;
+ private static final String SPARK_SHUFFLE_MANAGER = "spark.shuffle.manager";
+ private static final String FALLBACK_SPARK_SHUFFLE_MANAGER = "spark.fallback.shuffle.manager";
+ private static final String DEFAULT_STORAGE_FORMAT = "DWRF";
+ private static Optional dataDirectory = Optional.empty();
+
+ private PrestoSparkNativeQueryRunnerUtils() {}
+
+ public static Map getNativeExecutionSessionConfigs()
+ {
+ ImmutableMap.Builder builder = new ImmutableMap.Builder()
+ // Do not use default Prestissimo config files. Presto-Spark will generate the configs on-the-fly.
+ .put("catalog.config-dir", "/")
+ .put("task.info-update-interval", "100ms")
+ .put("spark.initial-partition-count", "1")
+ .put("register-test-functions", "true")
+ .put("native-execution-program-arguments", "--logtostderr=1 --minloglevel=3")
+ .put("spark.partition-count-auto-tune-enabled", "false");
+
+ if (System.getProperty("NATIVE_PORT") == null) {
+ builder.put("native-execution-executable-path", getNativeQueryRunnerParameters().serverBinary.toString());
+ }
+
+ try {
+ builder.put("native-execution-broadcast-base-path",
+ Files.createTempDirectory("native_broadcast").toAbsolutePath().toString());
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException("Error creating temporary directory for broadcast", e);
+ }
+
+ return builder.build();
+ }
+
+ public static PrestoSparkQueryRunner createHiveRunner()
+ {
+ PrestoSparkQueryRunner queryRunner = createRunner("hive", new NativeExecutionModule());
+ PrestoNativeQueryRunnerUtils.setupJsonFunctionNamespaceManager(queryRunner, "external_functions.json", "json");
+
+ return queryRunner;
+ }
+
+ private static PrestoSparkQueryRunner createRunner(String defaultCatalog, NativeExecutionModule nativeExecutionModule)
+ {
+ // Increases log level to reduce log spamming while running test.
+ customizeLogging();
+ return createRunner(
+ defaultCatalog,
+ Optional.of(getBaseDataPath()),
+ getNativeExecutionSessionConfigs(),
+ getNativeExecutionShuffleConfigs(),
+ ImmutableList.of(nativeExecutionModule));
+ }
+
+ // Similar to createPrestoSparkNativeQueryRunner, but with custom connector config and without jsonFunctionNamespaceManager
+ public static PrestoSparkQueryRunner createTpchRunner()
+ {
+ return createRunner(
+ "tpchstandard",
+ new NativeExecutionModule(
+ Optional.of(new NativeExecutionConnectorConfig().setConnectorName("tpch"))));
+ }
+
+ public static PrestoSparkQueryRunner createRunner(String defaultCatalog, Optional baseDir, Map additionalConfigProperties, Map additionalSparkProperties, ImmutableList nativeModules)
+ {
+ ImmutableMap.Builder configBuilder = ImmutableMap.builder();
+ configBuilder.putAll(getNativeWorkerSystemProperties()).putAll(additionalConfigProperties);
+ Optional dataDir = baseDir.map(path -> Paths.get(path.toString() + '/' + DEFAULT_STORAGE_FORMAT));
+ PrestoSparkQueryRunner queryRunner = new PrestoSparkQueryRunner(
+ defaultCatalog,
+ configBuilder.build(),
+ getNativeWorkerHiveProperties(),
+ additionalSparkProperties,
+ dataDir,
+ nativeModules,
+ AVAILABLE_CPU_COUNT);
+
+ ExtendedHiveMetastore metastore = queryRunner.getMetastore();
+ if (!metastore.getDatabase(METASTORE_CONTEXT, "tpch").isPresent()) {
+ metastore.createDatabase(METASTORE_CONTEXT, createDatabaseMetastoreObject("tpch"));
+ }
+ return queryRunner;
+ }
+
+ public static QueryRunner createJavaQueryRunner()
+ throws Exception
+ {
+ return PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder()
+ .setAddStorageFormatToPath(true)
+ .setStorageFormat(DEFAULT_STORAGE_FORMAT)
+ .build();
+ }
+
+ public static void customizeLogging()
+ {
+ Logging logging = Logging.initialize();
+ logging.setLevel("org.apache.spark", WARN);
+ logging.setLevel("com.facebook.presto.spark", WARN);
+ }
+
+ private static Database createDatabaseMetastoreObject(String name)
+ {
+ return Database.builder()
+ .setDatabaseName(name)
+ .setOwnerName("public")
+ .setOwnerType(PrincipalType.ROLE)
+ .build();
+ }
+
+ private static Map getNativeExecutionShuffleConfigs()
+ {
+ ImmutableMap.Builder sparkConfigs = ImmutableMap.builder();
+ sparkConfigs.put(SPARK_SHUFFLE_MANAGER, "com.facebook.presto.spark.classloader_interface.PrestoSparkNativeExecutionShuffleManager");
+ sparkConfigs.put(FALLBACK_SPARK_SHUFFLE_MANAGER, "org.apache.spark.shuffle.sort.SortShuffleManager");
+ return sparkConfigs.build();
+ }
+
+ public static synchronized Path getBaseDataPath()
+ {
+ if (dataDirectory.isPresent()) {
+ return dataDirectory.get();
+ }
+
+ Optional dataDirectoryStr = getProperty("DATA_DIR");
+ if (!dataDirectoryStr.isPresent()) {
+ try {
+ dataDirectory = Optional.of(createTempDirectory("PrestoTest").toAbsolutePath());
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ else {
+ dataDirectory = Optional.of(getNativeQueryRunnerParameters().dataDirectory);
+ }
+ return dataDirectory.get();
+ }
+
+ // This is a temporary replacement for the function from
+ // HiveTestUtils.getProperty. But HiveTestUtils instantiation seems to be
+ // failing due to missing info in Hdfs setup. Until we fix that, this is a
+ // copy of that function. As it's a simple utility function, we are ok to punt on
+ // fixing this
+ // TODO: Use HiveTestUtils.getProperty and delete this function
+ public static Optional getProperty(String name)
+ {
+ String systemPropertyValue = System.getProperty(name);
+ String environmentVariableValue = System.getenv(name);
+ if (systemPropertyValue == null) {
+ if (environmentVariableValue == null) {
+ return Optional.empty();
+ }
+ else {
+ return Optional.of(environmentVariableValue);
+ }
+ }
+ else {
+ if (environmentVariableValue != null && !systemPropertyValue.equals(environmentVariableValue)) {
+ throw new IllegalArgumentException(format("%s is set in both Java system property and environment variable, but their values are different. The Java system property value is %s, while the" +
+ " environment variable value is %s. Please use only one value.",
+ name,
+ systemPropertyValue,
+ environmentVariableValue));
+ }
+ return Optional.of(systemPropertyValue);
+ }
+ }
+}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkExpressionCompiler.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkExpressionCompiler.java
new file mode 100644
index 0000000000000..9c854ad3782ed
--- /dev/null
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkExpressionCompiler.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark;
+
+import com.facebook.presto.nativeworker.AbstractTestExpressionCompiler;
+import com.facebook.presto.testing.QueryRunner;
+
+public class TestPrestoSparkExpressionCompiler
+ extends AbstractTestExpressionCompiler
+{
+ @Override
+ protected QueryRunner getQueryRunner()
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createHiveRunner();
+ }
+}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeAggregations.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeAggregations.java
new file mode 100644
index 0000000000000..80da866e9102e
--- /dev/null
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeAggregations.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark;
+
+import com.facebook.presto.nativeworker.AbstractTestNativeAggregations;
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.facebook.presto.testing.QueryRunner;
+
+public class TestPrestoSparkNativeAggregations
+ extends AbstractTestNativeAggregations
+{
+ @Override
+ protected QueryRunner createQueryRunner()
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createHiveRunner();
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ throws Exception
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createJavaQueryRunner();
+ }
+}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeArrayFunctionQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeArrayFunctionQueries.java
new file mode 100644
index 0000000000000..40db300489142
--- /dev/null
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeArrayFunctionQueries.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark;
+
+import com.facebook.presto.nativeworker.AbstractTestNativeArrayFunctionQueries;
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.facebook.presto.testing.QueryRunner;
+
+public class TestPrestoSparkNativeArrayFunctionQueries
+ extends AbstractTestNativeArrayFunctionQueries
+{
+ @Override
+ protected QueryRunner createQueryRunner()
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createHiveRunner();
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ throws Exception
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createJavaQueryRunner();
+ }
+}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeBitwiseFunctionQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeBitwiseFunctionQueries.java
new file mode 100644
index 0000000000000..1f1f010a3dafe
--- /dev/null
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeBitwiseFunctionQueries.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark;
+
+import com.facebook.presto.nativeworker.AbstractTestNativeBitwiseFunctionQueries;
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.facebook.presto.testing.QueryRunner;
+
+public class TestPrestoSparkNativeBitwiseFunctionQueries
+ extends AbstractTestNativeBitwiseFunctionQueries
+{
+ @Override
+ protected QueryRunner createQueryRunner()
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createHiveRunner();
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ throws Exception
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createJavaQueryRunner();
+ }
+}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeGeneralQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeGeneralQueries.java
new file mode 100644
index 0000000000000..3160b0b3d69c6
--- /dev/null
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeGeneralQueries.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark;
+
+import com.facebook.presto.nativeworker.AbstractTestNativeGeneralQueries;
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.facebook.presto.testing.QueryRunner;
+import org.testng.annotations.Ignore;
+
+public class TestPrestoSparkNativeGeneralQueries
+ extends AbstractTestNativeGeneralQueries
+{
+ @Override
+ protected QueryRunner createQueryRunner()
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createHiveRunner();
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ throws Exception
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createJavaQueryRunner();
+ }
+
+ // TODO: Enable following Ignored tests after fixing (Tests can be enabled by removing the method)
+ @Override
+ @Ignore
+ public void testCatalogWithCacheEnabled() {}
+
+ @Override
+ @Ignore
+ public void testAnalyzeStatsOnDecimals() {}
+
+ // VeloxUserError: Unsupported file format in TableWrite: "ORC".
+ @Override
+ @Ignore
+ public void testColumnFilter() {}
+
+ // VeloxUserError: Unsupported file format in TableWrite: "ORC".
+ @Override
+ @Ignore
+ public void testIPAddressIPPrefix() {}
+
+ // VeloxUserError: Unsupported file format in TableWrite: "ORC".
+ @Override
+ @Ignore
+ public void testInvalidUuid() {}
+
+ // VeloxUserError: Unsupported file format in TableWrite: "ORC".
+ @Override
+ @Ignore
+ public void testStringFunctions() {}
+
+ // VeloxUserError: Unsupported file format in TableWrite: "ORC".
+ @Override
+ @Ignore
+ public void testUuid() {}
+
+ // Access Denied: Cannot set catalog session property
+ // hive.parquet_pushdown_filter_enabled
+ @Override
+ @Ignore
+ public void testDecimalApproximateAggregates() {}
+
+ // Access Denied: Cannot set catalog session property
+ // hive.parquet_pushdown_filter_enabled
+ @Override
+ @Ignore
+ public void testDecimalRangeFilters() {}
+
+ // Access Denied: Cannot set catalog session property
+ // hive.pushdown_filter_enabled
+ @Override
+ @Ignore
+ public void testTimestampWithTimeZone() {}
+
+ @Override
+ @Ignore
+ public void testDistributedSortSingleNode() {}
+
+ // VeloxRuntimeError: ReaderFactory is not registered for format text
+ @Override
+ @Ignore
+ public void testReadTableWithTextfileFormat() {}
+
+ @Override
+ @Ignore
+ public void testInformationSchemaTables() {}
+
+ @Override
+ @Ignore
+ public void testShowAndDescribe() {}
+
+ @Override
+ public void testSystemTables() {}
+
+ // @TODO Refer https://github.com/prestodb/presto/issues/20294
+ @Override
+ @Ignore
+ public void testAnalyzeStats() {}
+
+ // https://github.com/prestodb/presto/issues/22275
+ @Override
+ @Ignore
+ public void testUnionAllInsert() {}
+
+ @Override
+ @Ignore
+ public void testShowSessionWithoutJavaSessionProperties() {}
+
+ @Override
+ @Ignore
+ public void testSetSessionJavaWorkerSessionProperty() {}
+
+ @Override
+ @Ignore
+ public void testRowWiseExchange() {}
+}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeJoinQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeJoinQueries.java
new file mode 100644
index 0000000000000..b0b29d813f5a4
--- /dev/null
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeJoinQueries.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark;
+
+import com.facebook.presto.nativeworker.AbstractTestNativeJoinQueries;
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.facebook.presto.testing.QueryRunner;
+
+public class TestPrestoSparkNativeJoinQueries
+ extends AbstractTestNativeJoinQueries
+{
+ @Override
+ protected QueryRunner createQueryRunner()
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createHiveRunner();
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ throws Exception
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createJavaQueryRunner();
+ }
+
+ @Override
+ public Object[][] joinTypeProviderImpl()
+ {
+ return new Object[][] {{partitionedJoin()}};
+ }
+}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeSimpleQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeSimpleQueries.java
new file mode 100644
index 0000000000000..834814f66d75d
--- /dev/null
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeSimpleQueries.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.facebook.presto.testing.QueryRunner;
+import com.facebook.presto.tests.AbstractTestQueryFramework;
+import org.testng.annotations.Ignore;
+import org.testng.annotations.Test;
+
+import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createBucketedCustomer;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createBucketedLineitemAndOrders;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createCustomer;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createEmptyTable;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createNation;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrdersEx;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrdersHll;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createPart;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createPartSupp;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createPartitionedNation;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createPrestoBenchTables;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createRegion;
+import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createSupplier;
+import static com.facebook.presto.spark.PrestoSparkSessionProperties.SPARK_BROADCAST_JOIN_MAX_MEMORY_OVERRIDE;
+import static com.facebook.presto.spark.PrestoSparkSessionProperties.SPARK_RETRY_ON_OUT_OF_MEMORY_BROADCAST_JOIN_ENABLED;
+
+public class TestPrestoSparkNativeSimpleQueries
+ extends AbstractTestQueryFramework
+{
+    @Override
+    protected void createTables()
+    {
+        QueryRunner queryRunner = (QueryRunner) getExpectedQueryRunner();
+        createLineitem(queryRunner);
+        createOrders(queryRunner);
+        createOrdersHll(queryRunner);
+        createOrdersEx(queryRunner);
+        createNation(queryRunner);
+        createRegion(queryRunner);
+        createPartitionedNation(queryRunner);
+        createBucketedCustomer(queryRunner);
+        createCustomer(queryRunner);
+        createPart(queryRunner);
+        createPartSupp(queryRunner);
+        // NOTE(review): a duplicate createRegion(queryRunner) call was removed here.
+        createSupplier(queryRunner);
+        createEmptyTable(queryRunner);
+        createPrestoBenchTables(queryRunner);
+        createBucketedLineitemAndOrders(queryRunner);
+    }
+
+ @Override
+ protected QueryRunner createQueryRunner()
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createHiveRunner();
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ throws Exception
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createJavaQueryRunner();
+ }
+
+ @Test
+ public void testMapOnlyQueries()
+ {
+ assertQuery("SELECT * FROM orders");
+ assertQuery("SELECT orderkey, custkey FROM orders WHERE orderkey <= 200");
+ assertQuery("SELECT nullif(orderkey, custkey) FROM orders");
+ assertQuery("SELECT orderkey, custkey FROM orders ORDER BY orderkey LIMIT 4");
+ }
+
+ @Test
+ public void testAggregations()
+ {
+ assertQuery("SELECT count(*) c FROM lineitem WHERE partkey % 10 = 1 GROUP BY partkey");
+ }
+
+ @Test
+ public void testJoins()
+ {
+ assertQuery("SELECT * FROM orders o, lineitem l WHERE o.orderkey = l.orderkey AND o.orderkey % 2 = 1");
+ }
+
+ @Test
+ public void testFailures()
+ {
+ assertQueryFails("SELECT orderkey / 0 FROM orders", ".*division by zero.*");
+ }
+
+ /**
+ * Test native execution of cpp functions declared via a json file.
+ * `eq()` Scalar function & `sum()` Aggregate function are defined in `src/test/resources/external_functions.json`
+ */
+ @Test
+    @Ignore("json schema based external function registration is failing. Fix it and re-enable this test")
+ public void testJsonFileBasedFunction()
+ {
+ assertQuery("SELECT json.test_schema.eq(1, linenumber) FROM lineitem", "SELECT 1 = linenumber FROM lineitem");
+ assertQuery("SELECT json.test_schema.sum(linenumber) FROM lineitem", "SELECT sum(linenumber) FROM lineitem");
+
+ // Test functions with complex types (array, map, and row).
+ assertQuery("SELECT json.test_schema.array_constructor(linenumber) FROM lineitem", "SELECT array_constructor(linenumber) FROM lineitem");
+
+ assertQuery("SELECT json.test_schema.map(json.test_schema.array_constructor(linenumber), json.test_schema.array_constructor(linenumber)) FROM lineitem", "SELECT map(array_constructor(linenumber), array_constructor(linenumber)) FROM lineitem");
+ assertQuery("SELECT json.test_schema.map_entries(json.test_schema.map(json.test_schema.array_constructor(linenumber), json.test_schema.array_constructor(linenumber))) FROM lineitem", "SELECT map_entries(map(array_constructor(linenumber), array_constructor(linenumber))) FROM lineitem");
+ }
+
+ /**
+ * Test aggregation using companion functions with partial and final aggregation steps handled by separate queries.
+ * The first query computes partial aggregation states and stores them in the avg_partial_states table.
+ * Subsequent queries read from avg_partial_states and aggregate the states to the final result.
+ */
+ @Test
+    @Ignore("json schema based external function registration is failing. Fix it and re-enable this test")
+ public void testAggregationCompanionFunction()
+ {
+ Session session = Session.builder(getSession())
+ .setCatalogSessionProperty("hive", "collect_column_statistics_on_write", "false")
+ .setCatalogSessionProperty("hive", "orc_compression_codec", "ZSTD")
+ .build();
+ try {
+ getQueryRunner().execute(session,
+ "CREATE TABLE avg_partial_states AS ( "
+ + "SELECT orderpriority, cast(json.test_schema.avg_partial(shippriority) as ROW(sum DOUBLE, count BIGINT)) as states "
+ + "FROM orders "
+ + "GROUP BY orderstatus, orderpriority "
+ + ")");
+
+ // Test group-by aggregation.
+ assertQuery(
+ "SELECT orderpriority, json.test_schema.avg_merge_extract_double(states) FROM avg_partial_states GROUP BY orderpriority",
+ "SELECT orderpriority, avg(shippriority) FROM orders GROUP BY orderpriority");
+ assertQuery(
+ "SELECT orderpriority, json.test_schema.avg_extract_double(json.test_schema.avg_merge(states)) FROM avg_partial_states GROUP BY orderpriority",
+ "SELECT orderpriority, avg(shippriority) FROM orders GROUP BY orderpriority");
+
+ // Test global aggregation.
+ assertQuery(
+ "SELECT json.test_schema.avg_merge_extract_double(states) FROM avg_partial_states",
+ "SELECT avg(shippriority) FROM orders");
+ assertQuery(
+ "SELECT json.test_schema.avg_extract_double(json.test_schema.avg_merge(states)) FROM avg_partial_states",
+ "SELECT avg(shippriority) FROM orders");
+ }
+ finally {
+ getQueryRunner().execute("DROP TABLE IF EXISTS avg_partial_states");
+ }
+ }
+
+ @Test
+ public void testRetryOnOutOfMemoryBroadcastJoin()
+ {
+ String query = "select l.orderkey from lineitem l join orders o on l.orderkey = o.orderkey ";
+
+ Session session = getSessionWithBroadcastJoinDistribution("10B", false);
+ // Query should fail with broadcast join OOM & retry disabled.
+ assertQueryFails(
+ session,
+ query,
+ "Query exceeded per-node broadcast memory limit of 10B \\[Max serialized broadcast size: .*kB\\]");
+
+ Session expectedSession = Session.builder(getSession())
+ .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "BROADCAST")
+ .build();
+ Session actualSession = getSessionWithBroadcastJoinDistribution("10B", true);
+
+ // Query should succeed with broadcast join OOM & retry enabled.
+ assertQuery(actualSession, query, expectedSession, query);
+ }
+
+ private Session getSessionWithBroadcastJoinDistribution(String broadcastJoinMaxMemory, Boolean retryOnBroadcastOutOfMemory)
+ {
+ return Session.builder(getSession())
+ .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "BROADCAST")
+ .setSystemProperty(SPARK_BROADCAST_JOIN_MAX_MEMORY_OVERRIDE, broadcastJoinMaxMemory)
+ .setSystemProperty(SPARK_RETRY_ON_OUT_OF_MEMORY_BROADCAST_JOIN_ENABLED, Boolean.toString(retryOnBroadcastOutOfMemory))
+ .build();
+ }
+}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeTpchConnectorQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeTpchConnectorQueries.java
new file mode 100644
index 0000000000000..b2a0bb1ad5d51
--- /dev/null
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeTpchConnectorQueries.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark;
+
+import com.facebook.presto.nativeworker.AbstractTestNativeTpchConnectorQueries;
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.facebook.presto.testing.QueryRunner;
+import org.testng.annotations.Ignore;
+
+public class TestPrestoSparkNativeTpchConnectorQueries
+ extends AbstractTestNativeTpchConnectorQueries
+{
+ @Override
+ protected QueryRunner createQueryRunner()
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createTpchRunner();
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ throws Exception
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createJavaQueryRunner();
+ }
+
+    @Override
+    public void testMissingTpchConnector()
+    {
+        super.testMissingTpchConnector(".*Catalog tpch does not exist.*"); // ".*" (was a bare "*") allows trailing text in the error message
+    }
+
+ @Override
+ @Ignore
+ public void testTpchTinyTables() {}
+}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeTpchQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeTpchQueries.java
new file mode 100644
index 0000000000000..82f2be5d033f0
--- /dev/null
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeTpchQueries.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark;
+
+import com.facebook.presto.nativeworker.AbstractTestNativeTpchQueries;
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.facebook.presto.testing.QueryRunner;
+import org.testng.annotations.Ignore;
+
+public class TestPrestoSparkNativeTpchQueries
+ extends AbstractTestNativeTpchQueries
+{
+ @Override
+ protected QueryRunner createQueryRunner()
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createHiveRunner();
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ throws Exception
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createJavaQueryRunner();
+ }
+
+ // TODO: Enable following Ignored tests after fixing (Tests can be enabled by removing the method)
+ // Following tests require broadcast join
+ @Override
+ @Ignore
+ public void testTpchQ7() {}
+
+ @Override
+ @Ignore
+ public void testTpchQ8() {}
+
+ @Override
+ @Ignore
+ public void testTpchQ11() {}
+
+ @Override
+ @Ignore
+ public void testTpchQ15() {}
+
+ @Override
+ @Ignore
+ public void testTpchQ18() {}
+
+ @Override
+ @Ignore
+ public void testTpchQ21() {}
+
+ @Override
+ @Ignore
+ public void testTpchQ22() {}
+}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeWindowQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeWindowQueries.java
new file mode 100644
index 0000000000000..bdd2dff3009fe
--- /dev/null
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkNativeWindowQueries.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark;
+
+import com.facebook.presto.nativeworker.AbstractTestNativeWindowQueries;
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.facebook.presto.testing.QueryRunner;
+
+public class TestPrestoSparkNativeWindowQueries
+ extends AbstractTestNativeWindowQueries
+{
+ @Override
+ protected QueryRunner createQueryRunner()
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createHiveRunner();
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ throws Exception
+ {
+ return PrestoSparkNativeQueryRunnerUtils.createJavaQueryRunner();
+ }
+}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkSqlFunctions.java b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkSqlFunctions.java
new file mode 100644
index 0000000000000..a9de4232645ba
--- /dev/null
+++ b/presto-native-execution/src/test/java/com/facebook/presto/spark/TestPrestoSparkSqlFunctions.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark;
+
+import com.facebook.presto.operator.scalar.TestCustomFunctions;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+
+public class TestPrestoSparkSqlFunctions
+ extends TestCustomFunctions
+{
+ public TestPrestoSparkSqlFunctions()
+ {
+ super(new FeaturesConfig().setNativeExecutionEnabled(true));
+ }
+}
diff --git a/presto-spark-base/pom.xml b/presto-spark-base/pom.xml
index f87490571a6d9..205a7ed872252 100644
--- a/presto-spark-base/pom.xml
+++ b/presto-spark-base/pom.xml
@@ -43,8 +43,8 @@
- com.facebook.presto.spark
- spark-core
+ org.apache.spark
+ spark-core_2.13
provided
@@ -68,7 +68,10 @@
com.facebook.presto
presto-main-base
-
+
+ com.facebook.presto
+ presto-main
+
com.facebook.presto
presto-expressions
@@ -104,6 +107,11 @@
json
+
+ com.facebook.airlift
+ http-client
+
+
com.fasterxml.jackson.core
jackson-annotations
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java
index f41f6365e1d37..e197d4a867804 100644
--- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java
@@ -126,6 +126,7 @@
import com.facebook.presto.spark.execution.property.NativeExecutionVeloxConfig;
import com.facebook.presto.spark.execution.shuffle.PrestoSparkLocalShuffleReadInfo;
import com.facebook.presto.spark.execution.shuffle.PrestoSparkLocalShuffleWriteInfo;
+import com.facebook.presto.spark.execution.task.PrestoSparkNativeTaskExecutorFactory;
import com.facebook.presto.spark.execution.task.PrestoSparkTaskExecutorFactory;
import com.facebook.presto.spark.node.PrestoSparkInternalNodeManager;
import com.facebook.presto.spark.node.PrestoSparkNodePartitioningManager;
@@ -557,6 +558,7 @@ protected void setup(Binder binder)
binder.bind(PrestoSparkAccessControlChecker.class).in(Scopes.SINGLETON);
binder.bind(PrestoSparkPlanFragmenter.class).in(Scopes.SINGLETON);
binder.bind(PrestoSparkRddFactory.class).in(Scopes.SINGLETON);
+ binder.bind(PrestoSparkNativeTaskExecutorFactory.class).in(Scopes.SINGLETON);
binder.bind(PrestoSparkTaskExecutorFactory.class).in(Scopes.SINGLETON);
binder.bind(PrestoSparkQueryExecutionFactory.class).in(Scopes.SINGLETON);
binder.bind(PrestoSparkService.class).in(Scopes.SINGLETON);
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkService.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkService.java
index 623010e073e5c..418d8dc7225f0 100644
--- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkService.java
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkService.java
@@ -17,6 +17,7 @@
import com.facebook.presto.spark.classloader_interface.IPrestoSparkQueryExecutionFactory;
import com.facebook.presto.spark.classloader_interface.IPrestoSparkService;
import com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutorFactory;
+import com.facebook.presto.spark.execution.task.PrestoSparkNativeTaskExecutorFactory;
import com.facebook.presto.spark.execution.task.PrestoSparkTaskExecutorFactory;
import javax.inject.Inject;
@@ -28,16 +29,19 @@ public class PrestoSparkService
{
private final PrestoSparkQueryExecutionFactory queryExecutionFactory;
private final PrestoSparkTaskExecutorFactory taskExecutorFactory;
+ private final PrestoSparkNativeTaskExecutorFactory prestoSparkNativeTaskExecutorFactory;
private final LifeCycleManager lifeCycleManager;
@Inject
public PrestoSparkService(
PrestoSparkQueryExecutionFactory queryExecutionFactory,
PrestoSparkTaskExecutorFactory taskExecutorFactory,
+ PrestoSparkNativeTaskExecutorFactory prestoSparkNativeTaskExecutorFactory,
LifeCycleManager lifeCycleManager)
{
this.queryExecutionFactory = requireNonNull(queryExecutionFactory, "queryExecutionFactory is null");
this.taskExecutorFactory = requireNonNull(taskExecutorFactory, "taskExecutorFactory is null");
+ this.prestoSparkNativeTaskExecutorFactory = requireNonNull(prestoSparkNativeTaskExecutorFactory, "prestoSparkNativeTaskExecutorFactory is null");
this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null");
}
@@ -53,10 +57,17 @@ public IPrestoSparkTaskExecutorFactory getTaskExecutorFactory()
return taskExecutorFactory;
}
+ @Override
+ public IPrestoSparkTaskExecutorFactory getNativeTaskExecutorFactory()
+ {
+ return prestoSparkNativeTaskExecutorFactory;
+ }
+
@Override
public void close()
{
lifeCycleManager.stop();
+ prestoSparkNativeTaskExecutorFactory.close();
taskExecutorFactory.close();
}
}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkServiceFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkServiceFactory.java
index 4a2aec585fd6b..25bb881296c3b 100644
--- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkServiceFactory.java
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkServiceFactory.java
@@ -19,6 +19,7 @@
import com.facebook.presto.spark.classloader_interface.PrestoSparkBootstrapTimer;
import com.facebook.presto.spark.classloader_interface.PrestoSparkConfiguration;
import com.facebook.presto.spark.classloader_interface.SparkProcessType;
+import com.facebook.presto.spark.execution.nativeprocess.NativeExecutionModule;
import com.facebook.presto.sql.parser.SqlParserOptions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
@@ -66,7 +67,9 @@ protected List getAdditionalModules(PrestoSparkConfiguration configurati
{
checkArgument(METADATA_STORAGE_TYPE_LOCAL.equalsIgnoreCase(configuration.getMetadataStorageType()), "only local metadata storage is supported");
return ImmutableList.of(
- new PrestoSparkLocalMetadataStorageModule());
+ new PrestoSparkLocalMetadataStorageModule(),
+ // TODO: Need to let NativeExecutionModule addition be controlled by configuration as well.
+ new NativeExecutionModule());
}
protected SqlParserOptions getSqlParserOptions()
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/http/PrestoSparkHttpServerClient.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/http/PrestoSparkHttpServerClient.java
new file mode 100644
index 0000000000000..219b012f5244c
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/http/PrestoSparkHttpServerClient.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.http;
+
+import com.facebook.airlift.http.client.HttpClient;
+import com.facebook.airlift.http.client.Request;
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.client.ServerInfo;
+import com.facebook.presto.server.smile.BaseResponse;
+import com.google.common.util.concurrent.ListenableFuture;
+
+import javax.annotation.concurrent.ThreadSafe;
+
+import java.net.URI;
+
+import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
+import static com.facebook.airlift.http.client.Request.Builder.prepareGet;
+import static com.facebook.presto.server.RequestHelpers.setContentTypeHeaders;
+import static com.facebook.presto.server.smile.AdaptingJsonResponseHandler.createAdaptingJsonResponseHandler;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * An abstraction of HTTP client that communicates with the locally running Presto worker process. It exposes worker's server level endpoints to simple method calls.
+ */
+@ThreadSafe
+public class PrestoSparkHttpServerClient
+{
+ private static final Logger log = Logger.get(PrestoSparkHttpServerClient.class);
+ private static final String SERVER_URI = "/v1/info";
+
+ private final HttpClient httpClient;
+ private final URI location;
+ private final URI serverUri;
+ private final JsonCodec serverInfoCodec;
+
+ public PrestoSparkHttpServerClient(
+ HttpClient httpClient,
+ URI location,
+ JsonCodec serverInfoCodec)
+ {
+ this.httpClient = requireNonNull(httpClient, "httpClient is null");
+ this.location = requireNonNull(location, "location is null");
+ this.serverInfoCodec = requireNonNull(serverInfoCodec, "serverInfoCodec is null");
+ this.serverUri = getServerUri(location);
+ }
+
+ public ListenableFuture> getServerInfo()
+ {
+ Request request = setContentTypeHeaders(false, prepareGet())
+ .setUri(serverUri)
+ .build();
+ return httpClient.executeAsync(request, createAdaptingJsonResponseHandler(serverInfoCodec));
+ }
+
+ public URI getLocation()
+ {
+ return location;
+ }
+
+ private URI getServerUri(URI baseUri)
+ {
+ return uriBuilderFrom(baseUri)
+ .appendPath(SERVER_URI)
+ .build();
+ }
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/http/PrestoSparkHttpTaskClient.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/http/PrestoSparkHttpTaskClient.java
new file mode 100644
index 0000000000000..a013632ec06cd
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/http/PrestoSparkHttpTaskClient.java
@@ -0,0 +1,455 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.http;
+
+import com.facebook.airlift.http.client.HeaderName;
+import com.facebook.airlift.http.client.HttpClient;
+import com.facebook.airlift.http.client.Request;
+import com.facebook.airlift.http.client.Response;
+import com.facebook.airlift.http.client.ResponseHandler;
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.presto.Session;
+import com.facebook.presto.execution.TaskId;
+import com.facebook.presto.execution.TaskInfo;
+import com.facebook.presto.execution.TaskSource;
+import com.facebook.presto.execution.buffer.OutputBuffers;
+import com.facebook.presto.execution.scheduler.TableWriteInfo;
+import com.facebook.presto.operator.HttpRpcShuffleClient.PageResponseHandler;
+import com.facebook.presto.operator.PageBufferClient.PagesResponse;
+import com.facebook.presto.server.RequestErrorTracker;
+import com.facebook.presto.server.SimpleHttpResponseCallback;
+import com.facebook.presto.server.SimpleHttpResponseHandler;
+import com.facebook.presto.server.SimpleHttpResponseHandlerStats;
+import com.facebook.presto.server.TaskUpdateRequest;
+import com.facebook.presto.server.smile.BaseResponse;
+import com.facebook.presto.spi.PrestoException;
+import com.facebook.presto.sql.planner.PlanFragment;
+import com.google.common.collect.ImmutableListMultimap;
+import com.google.common.collect.ListMultimap;
+import com.google.common.io.ByteStreams;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.SettableFuture;
+import io.airlift.units.DataSize;
+import io.airlift.units.Duration;
+
+import javax.annotation.concurrent.ThreadSafe;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URI;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ScheduledExecutorService;
+
+import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue;
+import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
+import static com.facebook.airlift.http.client.Request.Builder.prepareDelete;
+import static com.facebook.airlift.http.client.Request.Builder.prepareGet;
+import static com.facebook.airlift.http.client.Request.Builder.preparePost;
+import static com.facebook.airlift.http.client.ResponseHandlerUtils.propagate;
+import static com.facebook.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
+import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_SIZE;
+import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_WAIT;
+import static com.facebook.presto.server.RequestHelpers.setContentTypeHeaders;
+import static com.facebook.presto.server.smile.AdaptingJsonResponseHandler.createAdaptingJsonResponseHandler;
+import static com.facebook.presto.spi.StandardErrorCode.NATIVE_EXECUTION_TASK_ERROR;
+import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR;
+import static com.google.common.util.concurrent.Futures.addCallback;
+import static com.google.common.util.concurrent.Futures.transformAsync;
+import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * An abstraction of HTTP client that communicates with the locally running Presto worker process. It exposes worker endpoints to simple method calls.
+ */
+@ThreadSafe
+public class PrestoSparkHttpTaskClient
+{
+ private static final String TASK_URI = "/v1/task/";
+
+ private final HttpClient httpClient;
+ private final URI location;
+ private final URI taskUri;
+ private final JsonCodec taskInfoCodec;
+ private final JsonCodec planFragmentCodec;
+ private final JsonCodec taskUpdateRequestCodec;
+ private final Duration infoRefreshMaxWait;
+ private final Executor executor;
+ private final ScheduledExecutorService scheduledExecutorService;
+ private final Duration remoteTaskMaxErrorDuration;
+
+ public PrestoSparkHttpTaskClient(
+ HttpClient httpClient,
+ TaskId taskId,
+ URI location,
+ JsonCodec taskInfoCodec,
+ JsonCodec planFragmentCodec,
+ JsonCodec taskUpdateRequestCodec,
+ Duration infoRefreshMaxWait,
+ Executor executor,
+ ScheduledExecutorService scheduledExecutorService,
+ Duration remoteTaskMaxErrorDuration)
+ {
+ this.httpClient = requireNonNull(httpClient, "httpClient is null");
+ this.location = requireNonNull(location, "location is null");
+ this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null");
+ this.planFragmentCodec = requireNonNull(planFragmentCodec, "planFragmentCodec is null");
+ this.taskUpdateRequestCodec = requireNonNull(taskUpdateRequestCodec, "taskUpdateRequestCodec is null");
+ this.taskUri = createTaskUri(location, taskId);
+ this.infoRefreshMaxWait = requireNonNull(infoRefreshMaxWait, "infoRefreshMaxWait is null");
+ this.executor = requireNonNull(executor, "executor is null");
+ this.scheduledExecutorService = requireNonNull(scheduledExecutorService, "scheduledExecutorService is null");
+ this.remoteTaskMaxErrorDuration = requireNonNull(remoteTaskMaxErrorDuration, "remoteTaskMaxErrorDuration is null");
+ }
+
+ /**
+ * Get results from a native engine task that ends with none shuffle operator. It always fetches from a single buffer.
+ */
+ public ListenableFuture getResults(long token, DataSize maxResponseSize)
+ {
+ RequestErrorTracker errorTracker = new RequestErrorTracker(
+ "NativeExecution",
+ location,
+ NATIVE_EXECUTION_TASK_ERROR,
+ "getResults encountered too many errors talking to native process",
+ remoteTaskMaxErrorDuration,
+ scheduledExecutorService,
+ "sending update request to native process");
+ SettableFuture result = SettableFuture.create();
+ scheduleGetResultsRequest(prepareGetResultsRequest(token, maxResponseSize), errorTracker, result);
+ return result;
+ }
+
+ private void scheduleGetResultsRequest(
+ Request request,
+ RequestErrorTracker errorTracker,
+ SettableFuture result)
+ {
+ ListenableFuture responseFuture = transformAsync(
+ errorTracker.acquireRequestPermit(),
+ ignored -> {
+ errorTracker.startRequest();
+ return httpClient.executeAsync(request, new PageResponseHandler());
+ },
+ executor);
+ addCallback(responseFuture, new FutureCallback()
+ {
+ @Override
+ public void onSuccess(PagesResponse response)
+ {
+ errorTracker.requestSucceeded();
+ result.set(response);
+ }
+
+ @Override
+ public void onFailure(Throwable failure)
+ {
+ if (failure instanceof PrestoException) {
+ // do not retry on PrestoException
+ result.setException(failure);
+ return;
+ }
+ try {
+ errorTracker.requestFailed(failure);
+ scheduleGetResultsRequest(request, errorTracker, result);
+ }
+ catch (Throwable t) {
+ result.setException(t);
+ }
+ }
+ }, executor);
+ }
+
+ private Request prepareGetResultsRequest(long token, DataSize maxResponseSize)
+ {
+ return prepareGet()
+ .setHeader(PRESTO_MAX_SIZE, maxResponseSize.toString())
+ .setUri(uriBuilderFrom(taskUri)
+ .appendPath("/results/0")
+ .appendPath(String.valueOf(token))
+ .build())
+ .build();
+ }
+
+ public void acknowledgeResultsAsync(long nextToken)
+ {
+ URI uri = uriBuilderFrom(taskUri)
+ .appendPath("/results/0")
+ .appendPath(String.valueOf(nextToken))
+ .appendPath("acknowledge")
+ .build();
+ Request request = prepareGet().setUri(uri).build();
+ executeWithRetries("acknowledgeResults", "acknowledge task results are received", request, new BytesResponseHandler());
+ }
+
+ public ListenableFuture abortResultsAsync()
+ {
+ Request request = prepareDelete().setUri(
+ uriBuilderFrom(taskUri)
+ .appendPath("/results/0")
+ .build())
+ .build();
+ return asVoidFuture(executeWithRetries("abortResults", "abort task results", request, new BytesResponseHandler()));
+ }
+
+ private static ListenableFuture asVoidFuture(ListenableFuture> future)
+ {
+ return Futures.transform(future, (ignored) -> null, directExecutor());
+ }
+
+ public TaskInfo getTaskInfo()
+ {
+ Request request = setContentTypeHeaders(false, prepareGet())
+ .setHeader(PRESTO_MAX_WAIT, infoRefreshMaxWait.toString())
+ .setUri(taskUri)
+ .build();
+ ListenableFuture future = executeWithRetries(
+ "getTaskInfo",
+ "get remote task info",
+ request,
+ createAdaptingJsonResponseHandler(taskInfoCodec));
+ return getFutureValue(future);
+ }
+
+ public TaskInfo updateTask(
+ List sources,
+ PlanFragment planFragment,
+ TableWriteInfo tableWriteInfo,
+ Optional shuffleWriteInfo,
+ Optional broadcastBasePath,
+ Session session,
+ OutputBuffers outputBuffers)
+ {
+ Optional fragment = Optional.of(planFragment.bytesForTaskSerialization(planFragmentCodec));
+ Optional writeInfo = Optional.of(tableWriteInfo);
+ TaskUpdateRequest updateRequest = new TaskUpdateRequest(
+ session.toSessionRepresentation(),
+ session.getIdentity().getExtraCredentials(),
+ fragment,
+ sources,
+ outputBuffers,
+ writeInfo);
+ BatchTaskUpdateRequest batchTaskUpdateRequest = new BatchTaskUpdateRequest(updateRequest, shuffleWriteInfo, broadcastBasePath);
+
+ Request request = setContentTypeHeaders(false, preparePost())
+ .setUri(uriBuilderFrom(taskUri)
+ .appendPath("batch")
+ .build())
+ .setBodyGenerator(createStaticBodyGenerator(taskUpdateRequestCodec.toBytes(batchTaskUpdateRequest)))
+ .build();
+ ListenableFuture future = executeWithRetries(
+ "updateTask",
+ "create or update remote task",
+ request,
+ createAdaptingJsonResponseHandler(taskInfoCodec));
+ return getFutureValue(future);
+ }
+
+ public URI getLocation()
+ {
+ return location;
+ }
+
+ public URI getTaskUri()
+ {
+ return taskUri;
+ }
+
+ private URI createTaskUri(URI baseUri, TaskId taskId)
+ {
+ return uriBuilderFrom(baseUri)
+ .appendPath(TASK_URI)
+ .appendPath(taskId.toString())
+ .build();
+ }
+
+ private ListenableFuture executeWithRetries(
+ String name,
+ String description,
+ Request request,
+ ResponseHandler, RuntimeException> responseHandler)
+ {
+ RequestErrorTracker errorTracker = new RequestErrorTracker(
+ "NativeExecution",
+ location,
+ NATIVE_EXECUTION_TASK_ERROR,
+ name + " encountered too many errors talking to native process",
+ remoteTaskMaxErrorDuration,
+ scheduledExecutorService,
+ description);
+ SettableFuture result = SettableFuture.create();
+ scheduleRequest(request, responseHandler, errorTracker, result);
+ return result;
+ }
+
+ private void scheduleRequest(
+ Request request,
+ ResponseHandler, RuntimeException> responseHandler,
+ RequestErrorTracker errorTracker,
+ SettableFuture result)
+ {
+ ListenableFuture> responseFuture = transformAsync(
+ errorTracker.acquireRequestPermit(),
+ ignored -> {
+ errorTracker.startRequest();
+ return httpClient.executeAsync(request, responseHandler);
+ },
+ executor);
+ SimpleHttpResponseCallback callback = new SimpleHttpResponseCallback()
+ {
+ @Override
+ public void success(T value)
+ {
+ result.set(value);
+ }
+
+ @Override
+ public void failed(Throwable failure)
+ {
+ if (failure instanceof PrestoException) {
+ // do not retry on PrestoException
+ result.setException(failure);
+ return;
+ }
+ try {
+ errorTracker.requestFailed(failure);
+ scheduleRequest(request, responseHandler, errorTracker, result);
+ }
+ catch (Throwable t) {
+ result.setException(t);
+ }
+ }
+
+ @Override
+ public void fatal(Throwable cause)
+ {
+ result.setException(cause);
+ }
+ };
+ addCallback(
+ responseFuture,
+ new SimpleHttpResponseHandler<>(
+ callback,
+ location,
+ new SimpleHttpResponseHandlerStats(),
+ REMOTE_TASK_ERROR),
+ executor);
+ }
+
+ private static class BytesResponseHandler
+ implements ResponseHandler, RuntimeException>
+ {
+ @Override
+ public BaseResponse handleException(Request request, Exception exception)
+ {
+ throw propagate(request, exception);
+ }
+
+ @Override
+ public BaseResponse handle(Request request, Response response)
+ {
+ return new BytesResponse(
+ response.getStatusCode(),
+ response.getHeaders(),
+ readResponseBytes(response));
+ }
+
+ private static byte[] readResponseBytes(Response response)
+ {
+ try {
+ InputStream inputStream = response.getInputStream();
+ if (inputStream == null) {
+ return new byte[] {};
+ }
+ return ByteStreams.toByteArray(inputStream);
+ }
+ catch (IOException e) {
+ throw new RuntimeException("Error reading response from server", e);
+ }
+ }
+ }
+
+ private static class BytesResponse
+ implements BaseResponse
+ {
+ private final int statusCode;
+ private final ListMultimap headers;
+ private final byte[] bytes;
+
+ public BytesResponse(int statusCode, ListMultimap headers, byte[] bytes)
+ {
+ this.statusCode = statusCode;
+ this.headers = ImmutableListMultimap.copyOf(requireNonNull(headers, "headers is null"));
+ this.bytes = bytes;
+ }
+
+ @Override
+ public int getStatusCode()
+ {
+ return statusCode;
+ }
+
+ @Override
+ public String getHeader(String name)
+ {
+ List values = getHeaders().get(HeaderName.of(name));
+ return values.isEmpty() ? null : values.get(0);
+ }
+
+ @Override
+ public List getHeaders(String name)
+ {
+ return headers.get(HeaderName.of(name));
+ }
+
+ @Override
+ public ListMultimap getHeaders()
+ {
+ return headers;
+ }
+
+ @Override
+ public boolean hasValue()
+ {
+ return true;
+ }
+
+ @Override
+ public byte[] getValue()
+ {
+ return bytes;
+ }
+
+ @Override
+ public int getResponseSize()
+ {
+ return bytes.length;
+ }
+
+ @Override
+ public byte[] getResponseBytes()
+ {
+ return bytes;
+ }
+
+ @Override
+ public Exception getException()
+ {
+ return null;
+ }
+ }
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/DetachedNativeExecutionProcess.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/DetachedNativeExecutionProcess.java
new file mode 100644
index 0000000000000..8fce1cd3cc276
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/DetachedNativeExecutionProcess.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.nativeprocess;
+
+import com.facebook.airlift.http.client.HttpClient;
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.Session;
+import com.facebook.presto.client.ServerInfo;
+import com.facebook.presto.spark.execution.property.WorkerProperty;
+import io.airlift.units.Duration;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ScheduledExecutorService;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * This is a testing class that essentially does nothing. Its mere purpose is to disable the launching and killing of
+ * native process by native execution. Instead it allows the native execution to reuse the same externally launched
+ * process over and over again.
+ */
+public class DetachedNativeExecutionProcess
+ extends NativeExecutionProcess
+{
+ private static final Logger log = Logger.get(DetachedNativeExecutionProcess.class);
+
+ public DetachedNativeExecutionProcess(
+ String executablePath,
+ String programArguments,
+ Session session,
+ HttpClient httpClient,
+ ExecutorService executorService,
+ ScheduledExecutorService errorRetryScheduledExecutor,
+ JsonCodec serverInfoCodec,
+ Duration maxErrorDuration,
+ WorkerProperty, ?, ?, ?> workerProperty)
+ throws IOException
+ {
+ super(executablePath,
+ programArguments,
+ session,
+ httpClient,
+ executorService,
+ errorRetryScheduledExecutor,
+ serverInfoCodec,
+ maxErrorDuration,
+ workerProperty);
+ }
+
+ @Override
+ public void start()
+ throws ExecutionException, InterruptedException
+ {
+ log.info("Please use port " + getPort() + " for detached native process launching.");
+ // getServerInfoWithRetry will return a Future on the getting the ServerInfo from the native process, we
+ // intentionally block on the Future till the native process successfully response the ServerInfo to ensure the
+ // process has been launched and initialized correctly.
+ getServerInfoWithRetry().get();
+ }
+
+ /**
+ * The port Spark native is going to use instead of dynamically generate. Since this class is for local debugging
+ * only, there is no need to make this port configurable.
+ *
+ * @return a fixed port.
+ */
+ @Override
+ public int getPort()
+ {
+ String configuredPort = requireNonNull(System.getProperty("NATIVE_PORT"), "NATIVE_PORT not set for interactive debugging");
+ return Integer.valueOf(configuredPort);
+ }
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/DetachedNativeExecutionProcessFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/DetachedNativeExecutionProcessFactory.java
new file mode 100644
index 0000000000000..30727ba857aa2
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/DetachedNativeExecutionProcessFactory.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.nativeprocess;
+
+import com.facebook.airlift.http.client.HttpClient;
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.presto.Session;
+import com.facebook.presto.client.ServerInfo;
+import com.facebook.presto.spark.execution.property.WorkerProperty;
+import com.facebook.presto.spark.execution.task.ForNativeExecutionTask;
+import com.facebook.presto.spi.PrestoException;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import com.google.inject.Inject;
+import io.airlift.units.Duration;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import static com.facebook.presto.spi.StandardErrorCode.NATIVE_EXECUTION_PROCESS_LAUNCH_ERROR;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+public class DetachedNativeExecutionProcessFactory
+ extends NativeExecutionProcessFactory
+{
+ private final HttpClient httpClient;
+ private final ExecutorService coreExecutor;
+ private final ScheduledExecutorService errorRetryScheduledExecutor;
+ private final JsonCodec serverInfoCodec;
+ private final WorkerProperty, ?, ?, ?> workerProperty;
+
+ @Inject
+ public DetachedNativeExecutionProcessFactory(
+ @ForNativeExecutionTask HttpClient httpClient,
+ ExecutorService coreExecutor,
+ ScheduledExecutorService errorRetryScheduledExecutor,
+ JsonCodec serverInfoCodec,
+ WorkerProperty, ?, ?, ?> workerProperty,
+ FeaturesConfig featuresConfig)
+ {
+ super(httpClient, coreExecutor, errorRetryScheduledExecutor, serverInfoCodec, workerProperty, featuresConfig);
+ this.httpClient = requireNonNull(httpClient, "httpClient is null");
+ this.coreExecutor = requireNonNull(coreExecutor, "ecoreExecutor is null");
+ this.errorRetryScheduledExecutor = requireNonNull(errorRetryScheduledExecutor, "errorRetryScheduledExecutor is null");
+ this.serverInfoCodec = requireNonNull(serverInfoCodec, "serverInfoCodec is null");
+ this.workerProperty = requireNonNull(workerProperty, "workerProperty is null");
+ }
+
+ @Override
+ public NativeExecutionProcess getNativeExecutionProcess(Session session)
+ {
+ return createNativeExecutionProcess(session, new Duration(2, TimeUnit.MINUTES));
+ }
+
+ @Override
+ public NativeExecutionProcess createNativeExecutionProcess(Session session, Duration maxErrorDuration)
+ {
+ try {
+ return new DetachedNativeExecutionProcess(
+ getExecutablePath(),
+ getProgramArguments(),
+ session,
+ httpClient,
+ coreExecutor,
+ errorRetryScheduledExecutor,
+ serverInfoCodec,
+ maxErrorDuration,
+ workerProperty);
+ }
+ catch (IOException e) {
+ throw new PrestoException(NATIVE_EXECUTION_PROCESS_LAUNCH_ERROR, format("Cannot start native process: %s", e.getMessage()), e);
+ }
+ }
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskInfoFetcher.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskInfoFetcher.java
new file mode 100644
index 0000000000000..2f77b0f787832
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskInfoFetcher.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.nativeprocess;
+
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.execution.TaskInfo;
+import com.facebook.presto.spark.execution.http.PrestoSparkHttpTaskClient;
+import com.google.common.annotations.VisibleForTesting;
+import io.airlift.units.Duration;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import java.util.Optional;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static com.google.common.base.Throwables.throwIfUnchecked;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * This class helps to fetch {@link TaskInfo} for a native task through HTTP communications with a Presto worker. Upon calling start(), object of this class will continuously poll
+ * {@link TaskInfo} from Presto worker and update its internal {@link TaskInfo} buffer. Caller is responsible for retrieving updated {@link TaskInfo} by calling getTaskInfo()
+ * method.
+ * Caller is also responsible for calling stop() to release resource when this fetcher is no longer needed.
+ */
+public class HttpNativeExecutionTaskInfoFetcher
+{
+ private static final Logger log = Logger.get(HttpNativeExecutionTaskInfoFetcher.class);
+
+ private final PrestoSparkHttpTaskClient workerClient;
+ private final ScheduledExecutorService updateScheduledExecutor;
+ private final AtomicReference taskInfo = new AtomicReference<>();
+ private final Duration infoFetchInterval;
+ private final AtomicReference lastException = new AtomicReference<>();
+ private final Object taskFinished;
+
+ @GuardedBy("this")
+ private ScheduledFuture> scheduledFuture;
+
+ public HttpNativeExecutionTaskInfoFetcher(
+ ScheduledExecutorService updateScheduledExecutor,
+ PrestoSparkHttpTaskClient workerClient,
+ Duration infoFetchInterval,
+ Object taskFinished)
+ {
+ this.workerClient = requireNonNull(workerClient, "workerClient is null");
+ this.updateScheduledExecutor = requireNonNull(updateScheduledExecutor, "updateScheduledExecutor is null");
+ this.infoFetchInterval = requireNonNull(infoFetchInterval, "infoFetchInterval is null");
+ this.taskFinished = requireNonNull(taskFinished, "taskFinished is null");
+ }
+
+ public void start()
+ {
+ scheduledFuture = updateScheduledExecutor.scheduleWithFixedDelay(
+ this::doGetTaskInfo, 0, (long) infoFetchInterval.getValue(), infoFetchInterval.getUnit());
+ }
+
+ public void stop()
+ {
+ if (scheduledFuture != null) {
+ scheduledFuture.cancel(false);
+ }
+ }
+
+ @VisibleForTesting
+ void doGetTaskInfo()
+ {
+ try {
+ TaskInfo result = workerClient.getTaskInfo();
+ onSuccess(result);
+ }
+ catch (Throwable t) {
+ onFailure(t);
+ }
+ }
+
+ private void onSuccess(TaskInfo result)
+ {
+ log.debug("TaskInfoCallback success %s", result.getTaskId());
+ taskInfo.set(result);
+ if (result.getTaskStatus().getState().isDone()) {
+ synchronized (taskFinished) {
+ taskFinished.notifyAll();
+ }
+ }
+ }
+
+ private void onFailure(Throwable failure)
+ {
+ stop();
+ lastException.set(failure);
+ synchronized (taskFinished) {
+ taskFinished.notifyAll();
+ }
+ }
+
+ public Optional getTaskInfo()
+ throws RuntimeException
+ {
+ if (scheduledFuture != null && scheduledFuture.isCancelled() && lastException.get() != null) {
+ Throwable failure = lastException.get();
+ throwIfUnchecked(failure);
+ throw new RuntimeException(failure);
+ }
+ TaskInfo info = taskInfo.get();
+ return info == null ? Optional.empty() : Optional.of(info);
+ }
+
+ public AtomicReference getLastException()
+ {
+ return lastException;
+ }
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskResultFetcher.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskResultFetcher.java
new file mode 100644
index 0000000000000..f94674948b04a
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskResultFetcher.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.nativeprocess;
+
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.operator.PageBufferClient;
+import com.facebook.presto.spark.execution.http.PrestoSparkHttpTaskClient;
+import com.facebook.presto.spi.HostAddress;
+import com.facebook.presto.spi.PrestoException;
+import com.facebook.presto.spi.page.SerializedPage;
+import io.airlift.units.DataSize;
+import io.airlift.units.Duration;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue;
+import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
+import static com.facebook.presto.spi.StandardErrorCode.SERIALIZED_PAGE_CHECKSUM_ERROR;
+import static com.facebook.presto.spi.page.PagesSerdeUtil.isChecksumValid;
+import static com.google.common.base.Throwables.throwIfUnchecked;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * This class helps to fetch results for a native task through HTTP communications with a Presto worker. The object of this class will give back a {@link CompletableFuture} to the
+ * caller upon start(). This future will be completed when retrievals of all results by the fetcher is completed. Results are retrieved and stored in an internal buffer, which is
+ * supposed to be polled by the caller. Note that the completion of the future does not mean all results have been consumed by the caller. The caller is responsible for making sure
+ * all results be consumed after future completion.
+ * There is a capacity cap (MAX_BUFFER_SIZE) for internal buffer managed by {@link HttpNativeExecutionTaskResultFetcher}. The fetcher will stop fetching results when buffer limit
+ * is hit and resume fetching after some of the buffer has been consumed, bringing buffer size down below the limit.
+ *
+ * The fetcher specifically serves to fetch table write commit metadata results from Presto worker so currently no shuffle result fetching is supported.
+ */
+public class HttpNativeExecutionTaskResultFetcher
+{
+ private static final Logger log = Logger.get(HttpNativeExecutionTaskResultFetcher.class);
+ private static final Duration FETCH_INTERVAL = new Duration(200, TimeUnit.MILLISECONDS);
+ private static final Duration POLL_TIMEOUT = new Duration(100, TimeUnit.MILLISECONDS);
+ private static final DataSize MAX_RESPONSE_SIZE = new DataSize(32, DataSize.Unit.MEGABYTE);
+ private static final DataSize MAX_BUFFER_SIZE = new DataSize(128, DataSize.Unit.MEGABYTE);
+
+ private final ScheduledExecutorService scheduler;
+ private final PrestoSparkHttpTaskClient workerClient;
+ private final LinkedBlockingDeque pageBuffer = new LinkedBlockingDeque<>();
+ private final AtomicLong bufferMemoryBytes;
+ private final Object taskHasResult;
+ private final AtomicReference lastException = new AtomicReference<>();
+
+ private ScheduledFuture> scheduledFuture;
+
+ private long token;
+
+ public HttpNativeExecutionTaskResultFetcher(
+ ScheduledExecutorService scheduler,
+ PrestoSparkHttpTaskClient workerClient,
+ Object taskHasResult)
+ {
+ this.scheduler = requireNonNull(scheduler, "scheduler is null");
+ this.workerClient = requireNonNull(workerClient, "workerClient is null");
+ this.bufferMemoryBytes = new AtomicLong();
+ this.taskHasResult = requireNonNull(taskHasResult, "taskHasResult is null");
+ }
+
+ public void start()
+ {
+ scheduledFuture = scheduler.scheduleAtFixedRate(this::doGetResults,
+ 0,
+ (long) FETCH_INTERVAL.getValue(),
+ FETCH_INTERVAL.getUnit());
+ }
+
+ public void stop(boolean success)
+ {
+ if (scheduledFuture != null) {
+ scheduledFuture.cancel(false);
+ }
+
+ if (success && !pageBuffer.isEmpty()) {
+ throw new PrestoException(GENERIC_INTERNAL_ERROR, format("TaskResultFetcher is closed with %s pages left in the buffer", pageBuffer.size()));
+ }
+ }
+
+ /**
+ * Blocking call to poll from result buffer. Blocks until content becomes
+ * available in the buffer, or until timeout is hit.
+ *
+ * @return the first {@link SerializedPage} result buffer contains.
+ */
+ public Optional pollPage()
+ throws InterruptedException
+ {
+ throwIfFailed();
+ SerializedPage page = pageBuffer.poll((long) POLL_TIMEOUT.getValue(), POLL_TIMEOUT.getUnit());
+ if (page != null) {
+ bufferMemoryBytes.addAndGet(-page.getSizeInBytes());
+ return Optional.of(page);
+ }
+ return Optional.empty();
+ }
+
+ public boolean hasPage()
+ {
+ throwIfFailed();
+ return !pageBuffer.isEmpty();
+ }
+
+ private void throwIfFailed()
+ {
+ if (scheduledFuture != null && scheduledFuture.isCancelled() && lastException.get() != null) {
+ Throwable failure = lastException.get();
+ throwIfUnchecked(failure);
+ throw new RuntimeException(failure);
+ }
+ }
+
+ private void doGetResults()
+ {
+ if (bufferMemoryBytes.longValue() >= MAX_BUFFER_SIZE.toBytes()) {
+ return;
+ }
+
+ try {
+ PageBufferClient.PagesResponse pagesResponse = getFutureValue(workerClient.getResults(token, MAX_RESPONSE_SIZE));
+ onSuccess(pagesResponse);
+ }
+ catch (Throwable t) {
+ onFailure(t);
+ }
+ }
+
+ private void onSuccess(PageBufferClient.PagesResponse pagesResponse)
+ {
+ List pages = pagesResponse.getPages();
+ long bytes = 0;
+ long positionCount = 0;
+ for (SerializedPage page : pages) {
+ if (!isChecksumValid(page)) {
+ throw new PrestoException(
+ SERIALIZED_PAGE_CHECKSUM_ERROR,
+ format("Received corrupted serialized page from host %s",
+ HostAddress.fromUri(workerClient.getLocation())));
+ }
+ bytes += page.getSizeInBytes();
+ positionCount += page.getPositionCount();
+ }
+ log.info("Received %s rows in %s pages from %s", positionCount, pages.size(), workerClient.getTaskUri());
+
+ pageBuffer.addAll(pages);
+ bufferMemoryBytes.addAndGet(bytes);
+ long nextToken = pagesResponse.getNextToken();
+ if (pages.size() > 0) {
+ workerClient.acknowledgeResultsAsync(nextToken);
+ }
+ token = nextToken;
+ if (pagesResponse.isClientComplete()) {
+ workerClient.abortResultsAsync();
+ scheduledFuture.cancel(false);
+ }
+ if (!pages.isEmpty()) {
+ synchronized (taskHasResult) {
+ taskHasResult.notifyAll();
+ }
+ }
+ }
+
+ private void onFailure(Throwable t)
+ {
+ workerClient.abortResultsAsync();
+ stop(false);
+ lastException.set(t);
+ synchronized (taskHasResult) {
+ taskHasResult.notifyAll();
+ }
+ }
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/NativeExecutionModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/NativeExecutionModule.java
new file mode 100644
index 0000000000000..a1544c2e8de51
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/NativeExecutionModule.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.nativeprocess;
+
+import com.facebook.presto.spark.execution.property.NativeExecutionConnectorConfig;
+import com.facebook.presto.spark.execution.property.NativeExecutionNodeConfig;
+import com.facebook.presto.spark.execution.property.NativeExecutionSystemConfig;
+import com.facebook.presto.spark.execution.property.NativeExecutionVeloxConfig;
+import com.facebook.presto.spark.execution.property.PrestoSparkWorkerProperty;
+import com.facebook.presto.spark.execution.property.WorkerProperty;
+import com.facebook.presto.spark.execution.shuffle.PrestoSparkLocalShuffleInfoTranslator;
+import com.facebook.presto.spark.execution.shuffle.PrestoSparkShuffleInfoTranslator;
+import com.facebook.presto.spark.execution.task.ForNativeExecutionTask;
+import com.facebook.presto.spark.execution.task.NativeExecutionTaskFactory;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.inject.Binder;
+import com.google.inject.Module;
+import com.google.inject.Scopes;
+import com.google.inject.TypeLiteral;
+import io.airlift.units.Duration;
+
+import java.util.Optional;
+
+import static com.facebook.airlift.http.client.HttpClientBinder.httpClientBinder;
+import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;
+import static java.util.concurrent.TimeUnit.SECONDS;
+
+public class NativeExecutionModule
+ implements Module
+{
+ private Optional<NativeExecutionConnectorConfig> connectorConfig;
+
+ // For use by production systems, where behavior can only be tuned via configuration.
+ public NativeExecutionModule()
+ {
+ this.connectorConfig = Optional.empty();
+ }
+
+ // In the future, we would make more bindings injected into NativeExecutionModule
+ // to be able to test various configuration parameters
+ @VisibleForTesting
+ public NativeExecutionModule(Optional<NativeExecutionConnectorConfig> connectorConfig)
+ {
+ this.connectorConfig = connectorConfig;
+ }
+
+ @Override
+ public void configure(Binder binder)
+ {
+ bindWorkerProperties(binder);
+ bindNativeExecutionTaskFactory(binder);
+ bindHttpClient(binder);
+ bindNativeExecutionProcess(binder);
+ bindShuffle(binder);
+ }
+
+ protected void bindShuffle(Binder binder)
+ {
+ binder.bind(PrestoSparkLocalShuffleInfoTranslator.class).in(Scopes.SINGLETON);
+ newOptionalBinder(binder, new TypeLiteral<PrestoSparkShuffleInfoTranslator>() {}).setDefault().to(PrestoSparkLocalShuffleInfoTranslator.class).in(Scopes.SINGLETON);
+ }
+
+ protected void bindWorkerProperties(Binder binder)
+ {
+ newOptionalBinder(binder, new TypeLiteral<WorkerProperty<?, ?, ?, ?>>() {}).setDefault().to(PrestoSparkWorkerProperty.class).in(Scopes.SINGLETON);
+ if (connectorConfig.isPresent()) {
+ binder.bind(PrestoSparkWorkerProperty.class).toInstance(new PrestoSparkWorkerProperty(connectorConfig.get(), new NativeExecutionNodeConfig(), new NativeExecutionSystemConfig(), new NativeExecutionVeloxConfig()));
+ }
+ else {
+ binder.bind(PrestoSparkWorkerProperty.class).in(Scopes.SINGLETON);
+ }
+ }
+
+ protected void bindHttpClient(Binder binder)
+ {
+ httpClientBinder(binder)
+ .bindHttpClient("nativeExecution", ForNativeExecutionTask.class)
+ .withConfigDefaults(config -> {
+ config.setRequestTimeout(new Duration(10, SECONDS));
+ config.setMaxConnectionsPerServer(250);
+ });
+ }
+
+ protected void bindNativeExecutionTaskFactory(Binder binder)
+ {
+ binder.bind(NativeExecutionTaskFactory.class).in(Scopes.SINGLETON);
+ }
+
+ protected void bindNativeExecutionProcess(Binder binder)
+ {
+ if (System.getProperty("NATIVE_PORT") != null) {
+ binder.bind(NativeExecutionProcessFactory.class).to(DetachedNativeExecutionProcessFactory.class).in(Scopes.SINGLETON);
+ }
+ else {
+ binder.bind(NativeExecutionProcessFactory.class).in(Scopes.SINGLETON);
+ }
+ }
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/NativeExecutionProcess.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/NativeExecutionProcess.java
new file mode 100644
index 0000000000000..055da7b5cd65a
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/NativeExecutionProcess.java
@@ -0,0 +1,494 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.nativeprocess;
+
+import com.facebook.airlift.http.client.HttpClient;
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.Session;
+import com.facebook.presto.client.ServerInfo;
+import com.facebook.presto.server.RequestErrorTracker;
+import com.facebook.presto.server.smile.BaseResponse;
+import com.facebook.presto.spark.classloader_interface.PrestoSparkFatalException;
+import com.facebook.presto.spark.execution.http.PrestoSparkHttpServerClient;
+import com.facebook.presto.spark.execution.property.WorkerProperty;
+import com.facebook.presto.spi.PrestoException;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.SettableFuture;
+import io.airlift.units.Duration;
+import org.apache.spark.SparkEnv$;
+import org.apache.spark.SparkFiles;
+
+import javax.annotation.Nullable;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileDescriptor;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.lang.reflect.Field;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
+import java.net.URI;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static com.facebook.airlift.http.client.HttpStatus.OK;
+import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilder;
+import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
+import static com.facebook.presto.spi.StandardErrorCode.NATIVE_EXECUTION_BINARY_NOT_EXIST;
+import static com.facebook.presto.spi.StandardErrorCode.NATIVE_EXECUTION_PROCESS_LAUNCH_ERROR;
+import static com.facebook.presto.spi.StandardErrorCode.NATIVE_EXECUTION_TASK_ERROR;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.util.concurrent.Futures.addCallback;
+import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
+import static java.lang.String.format;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static java.util.Objects.requireNonNull;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+
+public class NativeExecutionProcess
+ implements AutoCloseable
+{
+ private static final Logger log = Logger.get(NativeExecutionProcess.class);
+ private static final String NATIVE_EXECUTION_TASK_ERROR_MESSAGE = "Native process launch failed with multiple retries.";
+ private static final String WORKER_CONFIG_FILE = "/config.properties";
+ private static final String WORKER_NODE_CONFIG_FILE = "/node.properties";
+ private static final String WORKER_CONNECTOR_CONFIG_FILE = "/catalog/";
+ private static final int SIGSYS = 31;
+
+ private final String executablePath;
+ private final String programArguments;
+ private final Session session;
+ private final PrestoSparkHttpServerClient<ServerInfo> serverClient;
+ private final URI location;
+ private final int port;
+ private final Executor executor;
+ private final RequestErrorTracker errorTracker;
+ private final HttpClient httpClient;
+ private final WorkerProperty<?, ?, ?, ?> workerProperty;
+
+ private volatile Process process;
+ private volatile ProcessOutputPipe processOutputPipe;
+
+ public NativeExecutionProcess(
+ String executablePath,
+ String programArguments,
+ Session session,
+ HttpClient httpClient,
+ Executor executor,
+ ScheduledExecutorService scheduledExecutorService,
+ JsonCodec<ServerInfo> serverInfoCodec,
+ Duration maxErrorDuration,
+ WorkerProperty<?, ?, ?, ?> workerProperty)
+ throws IOException
+ {
+ this.executablePath = requireNonNull(executablePath, "executablePath is null");
+ this.programArguments = requireNonNull(programArguments, "programArguments is null");
+ String nodeInternalAddress = workerProperty.getNodeConfig().getNodeInternalAddress();
+ this.port = getAvailableTcpPort(nodeInternalAddress);
+ this.session = requireNonNull(session, "session is null");
+ this.location = uriBuilder()
+ .scheme("http")
+ .host(nodeInternalAddress)
+ .port(getPort())
+ .build();
+ this.httpClient = requireNonNull(httpClient, "httpClient is null");
+ this.serverClient = new PrestoSparkHttpServerClient<>(
+ this.httpClient,
+ location,
+ serverInfoCodec);
+ this.executor = requireNonNull(executor, "executor is null");
+ this.errorTracker = new RequestErrorTracker(
+ "NativeExecution",
+ location,
+ NATIVE_EXECUTION_TASK_ERROR,
+ NATIVE_EXECUTION_TASK_ERROR_MESSAGE,
+ maxErrorDuration,
+ scheduledExecutorService,
+ "getting native process status");
+ this.workerProperty = requireNonNull(workerProperty, "workerProperty is null");
+ }
+
+ /**
+ * Starts the external native execution process. The method will be blocked by connecting to the native process's /v1/info endpoint with backoff retries until timeout.
+ */
+ public synchronized void start()
+ throws ExecutionException, InterruptedException, IOException
+ {
+ if (process != null) {
+ return;
+ }
+
+ ProcessBuilder processBuilder = new ProcessBuilder(getLaunchCommand());
+ processBuilder.redirectOutput(ProcessBuilder.Redirect.INHERIT);
+ processBuilder.environment().put("INIT_PRESTO_QUERY_ID", session.getQueryId().toString());
+ try {
+ process = processBuilder.start();
+ processOutputPipe = new ProcessOutputPipe(
+ getPid(process),
+ process.getErrorStream(),
+ new FileOutputStream(FileDescriptor.err));
+ processOutputPipe.start();
+ }
+ catch (IOException e) {
+ log.error(format("Cannot start %s, error message: %s", processBuilder.command(), e.getMessage()));
+ throw new PrestoException(NATIVE_EXECUTION_PROCESS_LAUNCH_ERROR, format("Cannot start %s", processBuilder.command()), e);
+ }
+
+ // getServerInfoWithRetry will return a Future on the getting the ServerInfo from the native process, we intentionally block on the Future till
+ // the native process successfully responds with the ServerInfo to ensure the process has been launched and initialized correctly.
+ try {
+ getServerInfoWithRetry().get();
+ }
+ catch (Throwable t) {
+ close();
+ // If the native process launch failed, it usually indicates the current host machine is overloaded, we need to throw a fatal error (PrestoSparkFatalException is a
+ // subclass of fatal error VirtualMachineError) to let Spark shutdown current executor and fail over to another one (Here is the definition of scala fatal error Spark
+ // is relying on: https://www.scala-lang.org/api/2.13.3/scala/util/control/NonFatal$.html)
+ throw new PrestoSparkFatalException(t.getMessage(), t.getCause());
+ }
+ }
+
+ @VisibleForTesting
+ public SettableFuture<ServerInfo> getServerInfoWithRetry()
+ {
+ SettableFuture<ServerInfo> future = SettableFuture.create();
+ doGetServerInfo(future);
+ return future;
+ }
+
+ /**
+ * Triggers coredump (also terminates the process)
+ */
+ public void terminateWithCore(Duration timeout)
+ {
+ // chosen as the least likely core signal to occur naturally (invalid sys call)
+ // https://man7.org/linux/man-pages/man7/signal.7.html
+ Process process = sendSignal(SIGSYS);
+ if (process == null) {
+ return;
+ }
+ try {
+ long pid = getPid(process);
+ log.info("Waiting %s for process %s to terminate", timeout, pid);
+ if (!process.waitFor(timeout.toMillis(), MILLISECONDS)) {
+ log.warn("Process %s did not terminate within %s", pid, timeout);
+ process.destroyForcibly();
+ }
+ else {
+ log.info("Process %s successfully terminated with status code %s", pid, process.exitValue());
+ }
+ }
+ catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException(e);
+ }
+ }
+
+ private Process sendSignal(int signal)
+ {
+ Process process = this.process;
+ if (process == null) {
+ log.warn("Failure sending signal, process does not exist");
+ return null;
+ }
+ long pid = getPid(process);
+ if (!process.isAlive()) {
+ log.warn("Failure sending signal, process is dead: %s", pid);
+ return null;
+ }
+ try {
+ log.info("Sending signal to process %s: %s", pid, signal);
+ Runtime.getRuntime().exec(format("kill -%s %s", signal, pid));
+ return process;
+ }
+ catch (IOException e) {
+ log.warn(e, "Failure sending signal to process %s", pid);
+ return null;
+ }
+ }
+
+ private static long getPid(Process p)
+ {
+ try {
+ if (p.getClass().getName().equals("java.lang.UNIXProcess")) {
+ Field f = p.getClass().getDeclaredField("pid");
+ f.setAccessible(true);
+ long pid = f.getLong(p);
+ f.setAccessible(false);
+ return pid;
+ }
+ return -1;
+ }
+ catch (NoSuchFieldException | IllegalAccessException e) {
+ // should not happen
+ throw new AssertionError(e);
+ }
+ }
+
+ @Override
+ public void close()
+ {
+ Process process = this.process;
+ if (process == null) {
+ return;
+ }
+
+ if (process.isAlive()) {
+ long pid = getPid(process);
+ log.info("Destroying process: %s", pid);
+ process.destroy();
+ try {
+ // This 1 sec is arbitrary. Ideally, we do not need to give any heads up
+ // to CPP process on presto-on-spark native, because the resources
+ // are reclaimed by the container manager.
+ // For localmode, we still want to provide an opportunity for
+ // graceful termination as there is no resource/container manager.
+ process.waitFor(1, TimeUnit.SECONDS);
+ }
+ catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ finally {
+ if (process.isAlive()) {
+ log.warn("Graceful shutdown of native execution process failed. Force killing it: %s", pid);
+ process.destroyForcibly();
+ }
+ }
+ }
+ else {
+ log.info("Process is dead: %s", getPid(process));
+ }
+ }
+
+ public boolean isAlive()
+ {
+ return process != null && process.isAlive();
+ }
+
+ public String getCrashReport()
+ {
+ ProcessOutputPipe pipe = processOutputPipe;
+ if (pipe == null) {
+ return "";
+ }
+ return pipe.getAbortMessage();
+ }
+
+ public int getPort()
+ {
+ return port;
+ }
+
+ public URI getLocation()
+ {
+ return location;
+ }
+
+ private static int getAvailableTcpPort(String nodeInternalAddress)
+ {
+ try {
+ ServerSocket socket = new ServerSocket();
+ socket.bind(new InetSocketAddress(nodeInternalAddress, 0));
+ int port = socket.getLocalPort();
+ socket.close();
+ return port;
+ }
+ catch (Exception ex) {
+ // Something is wrong with the executor
+ // Fail the executor
+ throw new PrestoSparkFatalException("Failed to acquire port on host", ex);
+ }
+ }
+
+ private String getNativeExecutionCatalogName(Session session)
+ {
+ checkArgument(session.getCatalog().isPresent(), "Catalog isn't set in the session.");
+ return session.getCatalog().get();
+ }
+
+ private void populateConfigurationFiles(String configBasePath)
+ throws IOException
+ {
+ // The reason we have to pick and assign the port per worker is in our prod environment,
+ // there is no port isolation among all the containers running on the same host, so we have
+ // to pick unique port per worker to avoid port collision. This config will be passed down to
+ // the native execution process eventually for process initialization.
+ workerProperty.getSystemConfig().setHttpServerPort(port);
+ workerProperty.populateAllProperties(
+ Paths.get(configBasePath, WORKER_CONFIG_FILE),
+ Paths.get(configBasePath, WORKER_NODE_CONFIG_FILE),
+ Paths.get(configBasePath, format("%s%s.properties", WORKER_CONNECTOR_CONFIG_FILE, getNativeExecutionCatalogName(session))));
+ }
+
+ private void doGetServerInfo(SettableFuture<ServerInfo> future)
+ {
+ addCallback(serverClient.getServerInfo(), new FutureCallback<BaseResponse<ServerInfo>>()
+ {
+ @Override
+ public void onSuccess(@Nullable BaseResponse<ServerInfo> response)
+ {
+ if (response.getStatusCode() != OK.code()) {
+ throw new PrestoException(GENERIC_INTERNAL_ERROR, "Request failed with HTTP status " + response.getStatusCode());
+ }
+ future.set(response.getValue());
+ }
+
+ @Override
+ public void onFailure(Throwable failedReason)
+ {
+ if (failedReason instanceof RejectedExecutionException && httpClient.isClosed()) {
+ log.error(format("Unable to start the native process. HTTP client is closed. Reason: %s", failedReason.getMessage()));
+ future.setException(failedReason);
+ return;
+ }
+ // record failure
+ try {
+ errorTracker.requestFailed(failedReason);
+ }
+ catch (PrestoException e) {
+ future.setException(e);
+ return;
+ }
+ // if throttled due to error, asynchronously wait for timeout and try again
+ ListenableFuture<?> errorRateLimit = errorTracker.acquireRequestPermit();
+ if (errorRateLimit.isDone()) {
+ doGetServerInfo(future);
+ }
+ else {
+ errorRateLimit.addListener(() -> doGetServerInfo(future), executor);
+ }
+ }
+ }, directExecutor());
+ }
+
+ private String getProcessWorkingPath(String path)
+ {
+ File absolutePath = new File(path);
+ // If SparkEnv is not initialized (e.g. unit test), we just use the current location instead of calling SparkFiles.getRootDirectory() to avoid error.
+ String rootDirectory = SparkEnv$.MODULE$.get() != null ? SparkFiles.getRootDirectory() : ".";
+ File workingDir = new File(rootDirectory);
+ if (!absolutePath.isAbsolute()) {
+ absolutePath = new File(workingDir, path);
+ }
+
+ if (!absolutePath.exists()) {
+ log.error(format("File doesn't exist %s", absolutePath.getAbsolutePath()));
+ throw new PrestoException(NATIVE_EXECUTION_BINARY_NOT_EXIST, format("File doesn't exist %s", absolutePath.getAbsolutePath()));
+ }
+
+ return absolutePath.getAbsolutePath();
+ }
+
+ private List<String> getLaunchCommand()
+ throws IOException
+ {
+ String configPath = Paths.get(getProcessWorkingPath("./"), String.valueOf(port)).toAbsolutePath().toString();
+ ImmutableList.Builder<String> command = ImmutableList.builder();
+ List<String> argsList = Arrays.asList(programArguments.split("\\s+"));
+ boolean etcDirSet = false;
+ for (int i = 0; i < argsList.size(); i++) {
+ String arg = argsList.get(i);
+ if (arg.equals("--etc_dir")) {
+ etcDirSet = true;
+ configPath = argsList.get(i + 1);
+ break;
+ }
+ }
+ command.add(executablePath).addAll(argsList);
+ if (!etcDirSet) {
+ command.add("--etc_dir").add(configPath);
+ populateConfigurationFiles(configPath);
+ }
+ ImmutableList<String> commandList = command.build();
+ log.info("Launching native process using command: %s", String.join(" ", commandList));
+ return commandList;
+ }
+
+ private static class ProcessOutputPipe
+ implements Runnable
+ {
+ private final long pid;
+ private final InputStream inputStream;
+ private final OutputStream outputStream;
+ private final StringBuilder abortMessage = new StringBuilder();
+ private final AtomicBoolean started = new AtomicBoolean();
+
+ public ProcessOutputPipe(long pid, InputStream inputStream, OutputStream outputStream)
+ {
+ this.pid = pid;
+ this.inputStream = requireNonNull(inputStream, "inputStream is null");
+ this.outputStream = requireNonNull(outputStream, "outputStream is null");
+ }
+
+ public void start()
+ {
+ if (!started.compareAndSet(false, true)) {
+ return;
+ }
+ Thread t = new Thread(this, format("NativeExecutionProcess#ProcessOutputPipe[%s]", pid));
+ t.setDaemon(true);
+ t.start();
+ }
+
+ @Override
+ public void run()
+ {
+ try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, UTF_8));
+ BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(outputStream, UTF_8))) {
+ String line;
+ boolean aborted = false;
+ while ((line = reader.readLine()) != null) {
+ if (!aborted && line.startsWith("*** Aborted")) {
+ aborted = true;
+ }
+ if (aborted) {
+ synchronized (abortMessage) {
+ abortMessage.append(line).append("\n");
+ }
+ }
+ writer.write(line);
+ writer.newLine();
+ writer.flush();
+ }
+ }
+ catch (IOException e) {
+ log.warn(e, "failure occurred when copying streams");
+ }
+ }
+
+ public String getAbortMessage()
+ {
+ synchronized (abortMessage) {
+ return abortMessage.toString();
+ }
+ }
+ }
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/NativeExecutionProcessFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/NativeExecutionProcessFactory.java
new file mode 100644
index 0000000000000..cb3b677630b50
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/NativeExecutionProcessFactory.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.nativeprocess;
+
+import com.facebook.airlift.http.client.HttpClient;
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.presto.Session;
+import com.facebook.presto.client.ServerInfo;
+import com.facebook.presto.spark.execution.property.WorkerProperty;
+import com.facebook.presto.spark.execution.task.ForNativeExecutionTask;
+import com.facebook.presto.spi.PrestoException;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import io.airlift.units.Duration;
+
+import javax.annotation.PreDestroy;
+import javax.inject.Inject;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import static com.facebook.presto.SystemSessionProperties.isNativeExecutionProcessReuseEnabled;
+import static com.facebook.presto.spi.StandardErrorCode.NATIVE_EXECUTION_PROCESS_LAUNCH_ERROR;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+public class NativeExecutionProcessFactory
+{
+ private static final Duration MAX_ERROR_DURATION = new Duration(2, TimeUnit.MINUTES);
+ private final HttpClient httpClient;
+ private final ExecutorService coreExecutor;
+ private final ScheduledExecutorService errorRetryScheduledExecutor;
+ private final JsonCodec<ServerInfo> serverInfoCodec;
+ private final WorkerProperty<?, ?, ?, ?> workerProperty;
+ private final String executablePath;
+ private final String programArguments;
+
+ private static NativeExecutionProcess process;
+
+ @Inject
+ public NativeExecutionProcessFactory(
+ @ForNativeExecutionTask HttpClient httpClient,
+ ExecutorService coreExecutor,
+ ScheduledExecutorService errorRetryScheduledExecutor,
+ JsonCodec<ServerInfo> serverInfoCodec,
+ WorkerProperty<?, ?, ?, ?> workerProperty,
+ FeaturesConfig featuresConfig)
+ {
+ this.httpClient = requireNonNull(httpClient, "httpClient is null");
+ this.coreExecutor = requireNonNull(coreExecutor, "coreExecutor is null");
+ this.errorRetryScheduledExecutor = requireNonNull(errorRetryScheduledExecutor, "errorRetryScheduledExecutor is null");
+ this.serverInfoCodec = requireNonNull(serverInfoCodec, "serverInfoCodec is null");
+ this.workerProperty = requireNonNull(workerProperty, "workerProperty is null");
+ this.executablePath = featuresConfig.getNativeExecutionExecutablePath();
+ this.programArguments = featuresConfig.getNativeExecutionProgramArguments();
+ }
+
+ public synchronized NativeExecutionProcess getNativeExecutionProcess(Session session)
+ {
+ if (!isNativeExecutionProcessReuseEnabled(session) || process == null || !process.isAlive()) {
+ process = createNativeExecutionProcess(session, MAX_ERROR_DURATION);
+ }
+ return process;
+ }
+
+ public NativeExecutionProcess createNativeExecutionProcess(Session session, Duration maxErrorDuration)
+ {
+ try {
+ return new NativeExecutionProcess(
+ executablePath,
+ programArguments,
+ session,
+ httpClient,
+ coreExecutor,
+ errorRetryScheduledExecutor,
+ serverInfoCodec,
+ maxErrorDuration,
+ workerProperty);
+ }
+ catch (IOException e) {
+ throw new PrestoException(NATIVE_EXECUTION_PROCESS_LAUNCH_ERROR, format("Cannot start native process: %s", e.getMessage()), e);
+ }
+ }
+
+ @PreDestroy
+ public void stop()
+ {
+ coreExecutor.shutdownNow();
+ errorRetryScheduledExecutor.shutdownNow();
+ if (process != null) {
+ process.close();
+ }
+ }
+
+ protected String getExecutablePath()
+ {
+ return executablePath;
+ }
+
+ protected String getProgramArguments()
+ {
+ return programArguments;
+ }
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/ForNativeExecutionTask.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/ForNativeExecutionTask.java
new file mode 100644
index 0000000000000..a49c45dd2ad1c
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/ForNativeExecutionTask.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.task;
+
+import javax.inject.Qualifier;
+
+import java.lang.annotation.Retention;
+import java.lang.annotation.Target;
+
+import static java.lang.annotation.ElementType.FIELD;
+import static java.lang.annotation.ElementType.METHOD;
+import static java.lang.annotation.ElementType.PARAMETER;
+import static java.lang.annotation.RetentionPolicy.RUNTIME;
+
+@Retention(RUNTIME)
+@Target({FIELD, PARAMETER, METHOD})
+@Qualifier
+public @interface ForNativeExecutionTask
+{
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTask.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTask.java
new file mode 100644
index 0000000000000..fa08889dd8241
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTask.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.task;
+
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.Session;
+import com.facebook.presto.execution.TaskInfo;
+import com.facebook.presto.execution.TaskManagerConfig;
+import com.facebook.presto.execution.TaskSource;
+import com.facebook.presto.execution.buffer.OutputBuffers;
+import com.facebook.presto.execution.scheduler.TableWriteInfo;
+import com.facebook.presto.spark.execution.http.PrestoSparkHttpTaskClient;
+import com.facebook.presto.spark.execution.nativeprocess.HttpNativeExecutionTaskInfoFetcher;
+import com.facebook.presto.spark.execution.nativeprocess.HttpNativeExecutionTaskResultFetcher;
+import com.facebook.presto.spi.page.SerializedPage;
+import com.facebook.presto.sql.planner.PlanFragment;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.ScheduledExecutorService;
+
+import static com.facebook.presto.execution.TaskState.ABORTED;
+import static com.facebook.presto.execution.TaskState.CANCELED;
+import static com.facebook.presto.execution.TaskState.FAILED;
+import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * NativeExecutionTask provides the abstraction of executing tasks via a C++ worker.
+ * It is used for native execution of Presto on Spark. The plan and splits are provided at creation
+ * of NativeExecutionTask; it doesn't support adding more splits during execution.
+ *
+ * Caller should manage the lifecycle of the task by exposed APIs. The general workflow will look like:
+ * 1. Caller shall start() to start the task.
+ * 2. Caller shall call getTaskInfo() any time to get the current task info. Until the caller calls stop(), the task info fetcher will not stop fetching task info.
+ * 3. Caller shall call pollResult() continuously to poll result page from the internal buffer. The result fetcher will stop fetching more results if buffer hits its memory cap,
+ * until pages are fetched by caller to reduce the buffer under its memory cap.
+ * 4. Caller must call stop() to release resource.
+ */
+public class NativeExecutionTask
+{
+ private static final Logger log = Logger.get(NativeExecutionTask.class);
+
+ private final Session session;
+ private final PlanFragment planFragment;
+ private final OutputBuffers outputBuffers;
+ private final PrestoSparkHttpTaskClient workerClient;
+ private final TableWriteInfo tableWriteInfo;
+ private final Optional<String> shuffleWriteInfo;
+ private final Optional<String> broadcastBasePath;
+ private final List<TaskSource> sources;
+ private final HttpNativeExecutionTaskInfoFetcher taskInfoFetcher;
+ // Results will be fetched only if not written to shuffle.
+ private final Optional<HttpNativeExecutionTaskResultFetcher> taskResultFetcher;
+ private final Object taskFinishedOrHasResult = new Object();
+
+ public NativeExecutionTask(
+ Session session,
+ PrestoSparkHttpTaskClient workerClient,
+ PlanFragment planFragment,
+ List<TaskSource> sources,
+ TableWriteInfo tableWriteInfo,
+ Optional<String> shuffleWriteInfo,
+ Optional<String> broadcastBasePath,
+ ScheduledExecutorService scheduledExecutorService,
+ TaskManagerConfig taskManagerConfig)
+ {
+ this.session = requireNonNull(session, "session is null");
+ this.planFragment = requireNonNull(planFragment, "planFragment is null");
+ this.tableWriteInfo = requireNonNull(tableWriteInfo, "tableWriteInfo is null");
+ this.shuffleWriteInfo = requireNonNull(shuffleWriteInfo, "shuffleWriteInfo is null");
+ this.broadcastBasePath = requireNonNull(broadcastBasePath, "broadcastBasePath is null");
+ this.sources = requireNonNull(sources, "sources is null");
+ this.workerClient = requireNonNull(workerClient, "workerClient is null");
+ this.outputBuffers = createInitialEmptyOutputBuffers(planFragment.getPartitioningScheme().getPartitioning().getHandle()).withNoMoreBufferIds();
+ requireNonNull(taskManagerConfig, "taskManagerConfig is null");
+ requireNonNull(scheduledExecutorService, "scheduledExecutorService is null");
+ this.taskInfoFetcher = new HttpNativeExecutionTaskInfoFetcher(
+ scheduledExecutorService,
+ this.workerClient,
+ taskManagerConfig.getInfoUpdateInterval(),
+ taskFinishedOrHasResult);
+ if (!shuffleWriteInfo.isPresent()) {
+ this.taskResultFetcher = Optional.of(new HttpNativeExecutionTaskResultFetcher(
+ scheduledExecutorService,
+ this.workerClient,
+ taskFinishedOrHasResult));
+ }
+ else {
+ this.taskResultFetcher = Optional.empty();
+ }
+ }
+
+ /**
+ * Gets the most updated {@link TaskInfo} of the task of the native task.
+ *
+ * @return an {@link Optional} of most updated {@link TaskInfo}, empty {@link Optional} if {@link HttpNativeExecutionTaskInfoFetcher} has not yet retrieved the very first
+ * TaskInfo.
+ */
+ public Optional<TaskInfo> getTaskInfo()
+ throws RuntimeException
+ {
+ return taskInfoFetcher.getTaskInfo();
+ }
+
+ public boolean isTaskDone()
+ {
+ Optional<TaskInfo> taskInfo = getTaskInfo();
+ return taskInfo.isPresent() && taskInfo.get().getTaskStatus().getState().isDone();
+ }
+
+ public Object getTaskFinishedOrHasResult()
+ {
+ return taskFinishedOrHasResult;
+ }
+
+ /**
+ * Blocking call to poll from result fetcher buffer. Blocks until content becomes available in the buffer, or until timeout is hit.
+ *
+ * @return an Optional of the first {@link SerializedPage} result fetcher buffer contains, an empty Optional if no result is in the buffer.
+ */
+ public Optional<SerializedPage> pollResult()
+ throws InterruptedException
+ {
+ if (!taskResultFetcher.isPresent()) {
+ return Optional.empty();
+ }
+ return taskResultFetcher.get().pollPage();
+ }
+
+ public boolean hasResult()
+ {
+ return taskResultFetcher.isPresent() && taskResultFetcher.get().hasPage();
+ }
+
+ /**
+ * Blocking call to create and start native task.
+ *
+ * Starts background threads to fetch results and updated info.
+ */
+ public TaskInfo start()
+ {
+ TaskInfo taskInfo = sendUpdateRequest();
+
+ // We do not start taskInfo fetcher for failed tasks
+ if (!ImmutableList.of(CANCELED, FAILED, ABORTED).contains(taskInfo.getTaskStatus().getState())) {
+ log.info("Starting TaskInfoFetcher and TaskResultFetcher.");
+ taskResultFetcher.ifPresent(fetcher -> fetcher.start());
+ taskInfoFetcher.start();
+ }
+
+ return taskInfo;
+ }
+
+ /**
+ * Releases all resources, and kills all schedulers. It is caller's responsibility to call this method when NativeExecutionTask is no longer needed.
+ */
+ public void stop(boolean success)
+ {
+ taskInfoFetcher.stop();
+ taskResultFetcher.ifPresent(fetcher -> fetcher.stop(success));
+ workerClient.abortResultsAsync();
+ }
+
+ private TaskInfo sendUpdateRequest()
+ {
+ return workerClient.updateTask(
+ sources,
+ planFragment,
+ tableWriteInfo,
+ shuffleWriteInfo,
+ broadcastBasePath,
+ session,
+ outputBuffers);
+ }
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTaskFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTaskFactory.java
new file mode 100644
index 0000000000000..dec6882af4af7
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTaskFactory.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.task;
+
+import com.facebook.airlift.concurrent.BoundedExecutor;
+import com.facebook.airlift.http.client.HttpClient;
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.presto.Session;
+import com.facebook.presto.execution.QueryManagerConfig;
+import com.facebook.presto.execution.TaskId;
+import com.facebook.presto.execution.TaskInfo;
+import com.facebook.presto.execution.TaskManagerConfig;
+import com.facebook.presto.execution.TaskSource;
+import com.facebook.presto.execution.scheduler.TableWriteInfo;
+import com.facebook.presto.spark.execution.http.BatchTaskUpdateRequest;
+import com.facebook.presto.spark.execution.http.PrestoSparkHttpTaskClient;
+import com.facebook.presto.sql.planner.PlanFragment;
+
+import javax.annotation.PreDestroy;
+import javax.inject.Inject;
+
+import java.net.URI;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ScheduledExecutorService;
+
+import static java.util.Objects.requireNonNull;
+
+public class NativeExecutionTaskFactory
+{
+ // TODO add config
+ private static final int MAX_THREADS = 1000;
+
+ private final HttpClient httpClient;
+ private final ExecutorService coreExecutor;
+ private final Executor executor;
+ private final ScheduledExecutorService scheduledExecutorService;
+ private final JsonCodec<TaskInfo> taskInfoCodec;
+ private final JsonCodec<PlanFragment> planFragmentCodec;
+ private final JsonCodec<BatchTaskUpdateRequest> taskUpdateRequestCodec;
+ private final TaskManagerConfig taskManagerConfig;
+ private final QueryManagerConfig queryManagerConfig;
+
+ @Inject
+ public NativeExecutionTaskFactory(
+ @ForNativeExecutionTask HttpClient httpClient,
+ ExecutorService coreExecutor,
+ ScheduledExecutorService scheduledExecutorService,
+ JsonCodec<TaskInfo> taskInfoCodec,
+ JsonCodec<PlanFragment> planFragmentCodec,
+ JsonCodec<BatchTaskUpdateRequest> taskUpdateRequestCodec,
+ TaskManagerConfig taskManagerConfig,
+ QueryManagerConfig queryManagerConfig)
+ {
+ this.httpClient = requireNonNull(httpClient, "httpClient is null");
+ this.coreExecutor = requireNonNull(coreExecutor, "coreExecutor is null");
+ this.executor = new BoundedExecutor(coreExecutor, MAX_THREADS);
+ this.scheduledExecutorService = requireNonNull(scheduledExecutorService, "scheduledExecutorService is null");
+ this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null");
+ this.planFragmentCodec = requireNonNull(planFragmentCodec, "planFragmentCodec is null");
+ this.taskUpdateRequestCodec = requireNonNull(taskUpdateRequestCodec, "taskUpdateRequestCodec is null");
+ this.taskManagerConfig = requireNonNull(taskManagerConfig, "taskManagerConfig is null");
+ this.queryManagerConfig = requireNonNull(queryManagerConfig, "queryManagerConfig is null");
+ }
+
+ public NativeExecutionTask createNativeExecutionTask(
+ Session session,
+ URI location,
+ TaskId taskId,
+ PlanFragment fragment,
+ List<TaskSource> sources,
+ TableWriteInfo tableWriteInfo,
+ Optional<String> shuffleWriteInfo,
+ Optional<String> broadcastBasePath)
+ {
+ PrestoSparkHttpTaskClient workerClient = new PrestoSparkHttpTaskClient(
+ httpClient,
+ taskId,
+ location,
+ taskInfoCodec,
+ planFragmentCodec,
+ taskUpdateRequestCodec,
+ taskManagerConfig.getInfoRefreshMaxWait(),
+ executor,
+ scheduledExecutorService,
+ queryManagerConfig.getRemoteTaskMaxErrorDuration());
+ return new NativeExecutionTask(
+ session,
+ workerClient,
+ fragment,
+ sources,
+ tableWriteInfo,
+ shuffleWriteInfo,
+ broadcastBasePath,
+ scheduledExecutorService,
+ taskManagerConfig);
+ }
+
+ @PreDestroy
+ public void stop()
+ {
+ coreExecutor.shutdownNow();
+ scheduledExecutorService.shutdownNow();
+ }
+
+ public HttpClient getHttpClient()
+ {
+ return httpClient;
+ }
+
+ public ExecutorService getCoreExecutor()
+ {
+ return coreExecutor;
+ }
+
+ public Executor getExecutor()
+ {
+ return executor;
+ }
+
+ public ScheduledExecutorService getScheduledExecutorService()
+ {
+ return scheduledExecutorService;
+ }
+
+ public JsonCodec<TaskInfo> getTaskInfoCodec()
+ {
+ return taskInfoCodec;
+ }
+
+ public JsonCodec<PlanFragment> getPlanFragmentCodec()
+ {
+ return planFragmentCodec;
+ }
+
+ public JsonCodec<BatchTaskUpdateRequest> getTaskUpdateRequestCodec()
+ {
+ return taskUpdateRequestCodec;
+ }
+
+ public TaskManagerConfig getTaskManagerConfig()
+ {
+ return taskManagerConfig;
+ }
+
+ public QueryManagerConfig getQueryManagerConfig()
+ {
+ return queryManagerConfig;
+ }
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/PrestoSparkNativeTaskExecutorFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/PrestoSparkNativeTaskExecutorFactory.java
new file mode 100644
index 0000000000000..48acd0bbce25b
--- /dev/null
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/PrestoSparkNativeTaskExecutorFactory.java
@@ -0,0 +1,705 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.task;
+
+import com.facebook.airlift.json.Codec;
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.Session;
+import com.facebook.presto.common.RuntimeUnit;
+import com.facebook.presto.common.block.BlockEncodingManager;
+import com.facebook.presto.common.type.VarcharType;
+import com.facebook.presto.execution.ExecutionFailureInfo;
+import com.facebook.presto.execution.Lifespan;
+import com.facebook.presto.execution.Location;
+import com.facebook.presto.execution.ScheduledSplit;
+import com.facebook.presto.execution.StageExecutionId;
+import com.facebook.presto.execution.StageId;
+import com.facebook.presto.execution.TaskId;
+import com.facebook.presto.execution.TaskInfo;
+import com.facebook.presto.execution.TaskSource;
+import com.facebook.presto.execution.TaskState;
+import com.facebook.presto.metadata.RemoteTransactionHandle;
+import com.facebook.presto.metadata.SessionPropertyManager;
+import com.facebook.presto.metadata.Split;
+import com.facebook.presto.spark.PrestoSparkTaskDescriptor;
+import com.facebook.presto.spark.accesscontrol.PrestoSparkAuthenticatorProvider;
+import com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutor;
+import com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutorFactory;
+import com.facebook.presto.spark.classloader_interface.MutablePartitionId;
+import com.facebook.presto.spark.classloader_interface.PrestoSparkNativeTaskInputs;
+import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage;
+import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleReadDescriptor;
+import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats;
+import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskInputs;
+import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput;
+import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor;
+import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskSource;
+import com.facebook.presto.spark.classloader_interface.SerializedTaskInfo;
+import com.facebook.presto.spark.execution.BroadcastFileInfo;
+import com.facebook.presto.spark.execution.PrestoSparkBroadcastTableCacheManager;
+import com.facebook.presto.spark.execution.PrestoSparkExecutionExceptionFactory;
+import com.facebook.presto.spark.execution.nativeprocess.NativeExecutionProcess;
+import com.facebook.presto.spark.execution.nativeprocess.NativeExecutionProcessFactory;
+import com.facebook.presto.spark.execution.shuffle.PrestoSparkShuffleInfoTranslator;
+import com.facebook.presto.spark.execution.shuffle.PrestoSparkShuffleWriteInfo;
+import com.facebook.presto.spark.util.PrestoSparkStatsCollectionUtils;
+import com.facebook.presto.spark.util.PrestoSparkUtils;
+import com.facebook.presto.spi.PrestoException;
+import com.facebook.presto.spi.PrestoTransportException;
+import com.facebook.presto.spi.page.PagesSerde;
+import com.facebook.presto.spi.page.SerializedPage;
+import com.facebook.presto.spi.plan.PlanFragmentId;
+import com.facebook.presto.spi.plan.PlanNode;
+import com.facebook.presto.spi.plan.PlanNodeId;
+import com.facebook.presto.spi.plan.TableWriterNode;
+import com.facebook.presto.spi.security.TokenAuthenticator;
+import com.facebook.presto.split.RemoteSplit;
+import com.facebook.presto.sql.planner.PlanFragment;
+import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Sets;
+import com.sun.management.OperatingSystemMXBean;
+import io.airlift.units.Duration;
+import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.util.CollectionAccumulator;
+import scala.Tuple2;
+import scala.collection.AbstractIterator;
+import scala.collection.Iterator;
+
+import javax.inject.Inject;
+
+import java.io.IOException;
+import java.lang.management.ManagementFactory;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.OptionalLong;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static com.facebook.presto.operator.ExchangeOperator.REMOTE_CONNECTOR_ID;
+import static com.facebook.presto.spark.PrestoSparkSessionProperties.getNativeExecutionBroadcastBasePath;
+import static com.facebook.presto.spark.PrestoSparkSessionProperties.getNativeTerminateWithCoreTimeout;
+import static com.facebook.presto.spark.PrestoSparkSessionProperties.isNativeTerminateWithCoreWhenUnresponsiveEnabled;
+import static com.facebook.presto.spark.util.PrestoSparkUtils.deserializeZstdCompressed;
+import static com.facebook.presto.spark.util.PrestoSparkUtils.serializeZstdCompressed;
+import static com.facebook.presto.spark.util.PrestoSparkUtils.toPrestoSparkSerializedPage;
+import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
+import static com.facebook.presto.spi.StandardErrorCode.TOO_MANY_REQUESTS_FAILED;
+import static com.facebook.presto.sql.planner.SchedulingOrderVisitor.scheduleOrder;
+import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
+import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+import static io.airlift.units.DataSize.succinctBytes;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * PrestoSparkNativeTaskExecutorFactory is responsible for launching the external native process and managing the communication
+ * between Java process and native process (by using the {@link NativeExecutionTask}).
+ * It will send necessary metadata (e.g, plan fragment, session properties etc.) as a part of
+ * BatchTaskUpdateRequest. It will poll the remote CPP task for status and results (pages/data if applicable)
+ * and send these back to the Spark's RDD api
+ *
+ * PrestoSparkNativeTaskExecutorFactory is singleton instantiated once per executor.
+ *
+ * For every task it receives, it does the following
+ * 1. Create the Native execution Process (NativeTaskExecutionFactory), ensuring that it is created only once.
+ * 2. Serialize and pass the planFragment, source-metadata (taskSources), sink-metadata (tableWriteInfo or shuffleWriteInfo)
+ * and submit a nativeExecutionTask.
+ * 3. Return Iterator to sparkRDD layer. RDD execution will call the .next() methods, which will
+ * 3.a Call {@link NativeExecutionTask}'s pollResult() to retrieve {@link SerializedPage} back from external process.
+ * 3.b If no more output is available, then check if task has finished successfully or with exception
+ * If task finished with exception - fail the spark task (throw exception)
+ * If task finished successfully - collect statistics through taskInfo object and add to accumulator
+ */
+public class PrestoSparkNativeTaskExecutorFactory
+ implements IPrestoSparkTaskExecutorFactory
+{
+ private static final Logger log = Logger.get(PrestoSparkNativeTaskExecutorFactory.class);
+
+ // For Presto-on-Spark, we do not have remoteSourceTasks as the shuffle data is
+ // in persistent shuffle.
+ // Current protocol for Split mandates having a remoteSourceTaskId as the
+ // part of the split info. So for shuffleRead split we set it to a dummy
+ // value that is ignored by the shuffle-reader
+ private static final TaskId DUMMY_TASK_ID = TaskId.valueOf("remotesourcetaskid.0.0.0.0");
+
+ private final SessionPropertyManager sessionPropertyManager;
+ private final JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec;
+ private final JsonCodec<BroadcastFileInfo> broadcastFileInfoJsonCodec;
+ private final Codec<TaskSource> taskSourceCodec;
+ private final Codec<TaskInfo> taskInfoCodec;
+ private final PrestoSparkExecutionExceptionFactory executionExceptionFactory;
+ private final Set<PrestoSparkAuthenticatorProvider> authenticatorProviders;
+ private final NativeExecutionProcessFactory nativeExecutionProcessFactory;
+ private final NativeExecutionTaskFactory nativeExecutionTaskFactory;
+ private final PrestoSparkShuffleInfoTranslator shuffleInfoTranslator;
+ private final PagesSerde pagesSerde;
+ private NativeExecutionProcess nativeExecutionProcess;
+
+ private static class CpuTracker
+ {
+ private OperatingSystemMXBean operatingSystemMXBean;
+ private OptionalLong startCpuTime;
+
+ public CpuTracker()
+ {
+ if (ManagementFactory.getOperatingSystemMXBean() instanceof OperatingSystemMXBean) {
+ // we want the com.sun.management sub-interface of java.lang.management.OperatingSystemMXBean
+ operatingSystemMXBean = (OperatingSystemMXBean) ManagementFactory.getOperatingSystemMXBean();
+ startCpuTime = OptionalLong.of(operatingSystemMXBean.getProcessCpuTime());
+ }
+ else {
+ startCpuTime = OptionalLong.empty();
+ }
+ }
+
+ OptionalLong get()
+ {
+ if (operatingSystemMXBean != null) {
+ long endCpuTime = operatingSystemMXBean.getProcessCpuTime();
+ return OptionalLong.of(endCpuTime - startCpuTime.getAsLong());
+ }
+ else {
+ return OptionalLong.empty();
+ }
+ }
+ }
+
+ @Inject
+ public PrestoSparkNativeTaskExecutorFactory(
+ SessionPropertyManager sessionPropertyManager,
+ BlockEncodingManager blockEncodingManager,
+ JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec,
+ JsonCodec<BroadcastFileInfo> broadcastFileInfoJsonCodec,
+ Codec<TaskSource> taskSourceCodec,
+ Codec<TaskInfo> taskInfoCodec,
+ PrestoSparkExecutionExceptionFactory executionExceptionFactory,
+ Set<PrestoSparkAuthenticatorProvider> authenticatorProviders,
+ PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager,
+ NativeExecutionProcessFactory nativeExecutionProcessFactory,
+ NativeExecutionTaskFactory nativeExecutionTaskFactory,
+ PrestoSparkShuffleInfoTranslator shuffleInfoTranslator)
+ {
+ this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null");
+ this.taskDescriptorJsonCodec = requireNonNull(taskDescriptorJsonCodec, "sparkTaskDescriptorJsonCodec is null");
+ this.taskSourceCodec = requireNonNull(taskSourceCodec, "taskSourceCodec is null");
+ this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null");
+ this.broadcastFileInfoJsonCodec = requireNonNull(broadcastFileInfoJsonCodec, "broadcastFileInfoJsonCodec is null");
+ this.executionExceptionFactory = requireNonNull(executionExceptionFactory, "executionExceptionFactory is null");
+ this.authenticatorProviders = ImmutableSet.copyOf(requireNonNull(authenticatorProviders, "authenticatorProviders is null"));
+ this.nativeExecutionProcessFactory = requireNonNull(nativeExecutionProcessFactory, "processFactory is null");
+ this.nativeExecutionTaskFactory = requireNonNull(nativeExecutionTaskFactory, "taskFactory is null");
+ this.shuffleInfoTranslator = requireNonNull(shuffleInfoTranslator, "shuffleInfoFactory is null");
+ this.pagesSerde = PrestoSparkUtils.createPagesSerde(requireNonNull(blockEncodingManager, "blockEncodingManager is null"));
+ }
+
+ @Override
+ public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> create(
+ int partitionId,
+ int attemptNumber,
+ SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor,
+ Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources,
+ PrestoSparkTaskInputs inputs,
+ CollectionAccumulator<SerializedTaskInfo> taskInfoCollector,
+ CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector,
+ Class<T> outputType)
+ {
+ try {
+ return doCreate(
+ partitionId,
+ attemptNumber,
+ serializedTaskDescriptor,
+ serializedTaskSources,
+ inputs,
+ taskInfoCollector,
+ shuffleStatsCollector,
+ outputType);
+ }
+ catch (RuntimeException e) {
+ throw executionExceptionFactory.toPrestoSparkExecutionException(e);
+ }
+ }
+
+ public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> doCreate(
+ int partitionId,
+ int attemptNumber,
+ SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor,
+ Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources,
+ PrestoSparkTaskInputs inputs,
+ CollectionAccumulator<SerializedTaskInfo> taskInfoCollector,
+ CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector,
+ Class<T> outputType)
+ {
+ CpuTracker cpuTracker = new CpuTracker();
+
+ PrestoSparkTaskDescriptor taskDescriptor = taskDescriptorJsonCodec.fromJson(serializedTaskDescriptor.getBytes());
+ ImmutableMap.Builder extraAuthenticators = ImmutableMap.builder();
+ authenticatorProviders.forEach(provider -> extraAuthenticators.putAll(provider.getTokenAuthenticators()));
+
+ Session session = taskDescriptor.getSession().toSession(
+ sessionPropertyManager,
+ taskDescriptor.getExtraCredentials(),
+ extraAuthenticators.build());
+ PlanFragment fragment = taskDescriptor.getFragment();
+ StageId stageId = new StageId(session.getQueryId(), fragment.getId().getId());
+ TaskId taskId = new TaskId(new StageExecutionId(stageId, 0), partitionId, attemptNumber);
+
+ // TODO: Remove this once we can display the plan on Spark UI.
+ // Currently, `textPlanFragment` throws an exception if json-based UDFs are used in the query, which can only
+ // happen in native execution mode. To resolve this error, `JsonFileBasedFunctionNamespaceManager` must be
+ // loaded on the executors as well (which is actually not required for native execution). To do so, we need a
+ // mechanism to ship the JSON file containing the UDF metadata to workers, which does not exist as of today.
+ // TODO: Address this issue; more details in https://github.com/prestodb/presto/issues/19600
+ log.info("Logging plan fragment is not supported for presto-on-spark native execution, yet");
+
+ if (fragment.getPartitioning().isCoordinatorOnly()) {
+ throw new UnsupportedOperationException("Coordinator only fragment execution is not supported by native task executor");
+ }
+
+ checkArgument(
+ inputs instanceof PrestoSparkNativeTaskInputs,
+ format("PrestoSparkNativeTaskInputs is required for native execution, but %s is provided", inputs.getClass().getName()));
+
+ // 1. Start the native process if it hasn't already been started or dead
+ createAndStartNativeExecutionProcess(session);
+
+ // 2. compute the task info to send to cpp process
+ PrestoSparkNativeTaskInputs nativeInputs = (PrestoSparkNativeTaskInputs) inputs;
+ // 2.a Populate Read info
+ List<TaskSource> taskSources = getTaskSources(serializedTaskSources, fragment, session, nativeInputs);
+
+ boolean isFixedBroadcastDistribution = fragment.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION);
+ // 2.b Populate Shuffle Write info
+ Optional<PrestoSparkShuffleWriteInfo> shuffleWriteInfo = nativeInputs.getShuffleWriteDescriptor()
+ .map(descriptor -> shuffleInfoTranslator.createShuffleWriteInfo(session, descriptor));
+ Optional<String> serializedShuffleWriteInfo = shuffleWriteInfo.map(shuffleInfoTranslator::createSerializedWriteInfo);
+
+ // 2.c populate broadcast path
+ Optional<String> broadcastDirectory =
+ isFixedBroadcastDistribution ? Optional.of(getBroadcastDirectoryPath(session)) : Optional.empty();
+
+ boolean terminateWithCoreWhenUnresponsive = isNativeTerminateWithCoreWhenUnresponsiveEnabled(session);
+ Duration terminateWithCoreTimeout = getNativeTerminateWithCoreTimeout(session);
+ try {
+ // 3. Submit the task to cpp process for execution
+ log.info("Submitting native execution task ");
+ NativeExecutionTask task = nativeExecutionTaskFactory.createNativeExecutionTask(
+ session,
+ nativeExecutionProcess.getLocation(),
+ taskId,
+ fragment,
+ ImmutableList.copyOf(taskSources),
+ taskDescriptor.getTableWriteInfo(),
+ serializedShuffleWriteInfo,
+ broadcastDirectory);
+
+ log.info("Creating task and will wait for remote task completion");
+ TaskInfo taskInfo = task.start();
+
+ // task creation might have failed
+ processTaskInfoForErrorsOrCompletion(taskInfo);
+ // 4. return output to spark RDD layer
+ return new PrestoSparkNativeTaskOutputIterator<>(
+ partitionId,
+ task,
+ outputType,
+ taskInfoCollector,
+ taskInfoCodec,
+ executionExceptionFactory,
+ cpuTracker,
+ nativeExecutionProcess,
+ terminateWithCoreWhenUnresponsive,
+ terminateWithCoreTimeout);
+ }
+ catch (RuntimeException e) {
+ throw processFailure(e, nativeExecutionProcess, terminateWithCoreWhenUnresponsive, terminateWithCoreTimeout);
+ }
+ }
+
+ private String getBroadcastDirectoryPath(Session session)
+ {
+ return format("%s/%s", getNativeExecutionBroadcastBasePath(session), session.getQueryId().getId());
+ }
+
+ @Override
+ public void close()
+ {
+ if (nativeExecutionProcess != null) {
+ nativeExecutionProcess.close();
+ }
+ }
+
+ private static void completeTask(boolean success, CollectionAccumulator<SerializedTaskInfo> taskInfoCollector, NativeExecutionTask task, Codec<TaskInfo> taskInfoCodec, CpuTracker cpuTracker)
+ {
+ // stop the task
+ task.stop(success);
+
+ OptionalLong processCpuTime = cpuTracker.get();
+
+ // collect statistics (if available)
+ Optional<TaskInfo> taskInfoOptional = tryGetTaskInfo(task);
+ if (!taskInfoOptional.isPresent()) {
+ log.error("Missing taskInfo. Statistics might be inaccurate");
+ return;
+ }
+
+ // Record process-wide CPU time spent while executing this task. Since we run one task at a time,
+ // process-wide CPU time matches task's CPU time.
+ processCpuTime.ifPresent(cpuTime -> taskInfoOptional.get().getStats().getRuntimeStats()
+ .addMetricValue("javaProcessCpuTime", RuntimeUnit.NANO, cpuTime));
+
+ SerializedTaskInfo serializedTaskInfo = new SerializedTaskInfo(serializeZstdCompressed(taskInfoCodec, taskInfoOptional.get()));
+ taskInfoCollector.add(serializedTaskInfo);
+
+ // Update Spark Accumulators for spark internal metrics
+ PrestoSparkStatsCollectionUtils.collectMetrics(taskInfoOptional.get());
+ }
+
+ private static Optional<TaskInfo> tryGetTaskInfo(NativeExecutionTask task)
+ {
+ try {
+ return task.getTaskInfo();
+ }
+ catch (RuntimeException e) {
+ log.debug(e, "TaskInfo is not available");
+ return Optional.empty();
+ }
+ }
+
+ private static void processTaskInfoForErrorsOrCompletion(TaskInfo taskInfo)
+ {
+ if (!taskInfo.getTaskStatus().getState().isDone()) {
+ log.info("processTaskInfoForErrors: task is not done yet.. %s", taskInfo);
+ return;
+ }
+
+ if (!taskInfo.getTaskStatus().getState().equals(TaskState.FINISHED)) {
+ // task failed with errors
+ RuntimeException failure = taskInfo.getTaskStatus().getFailures().stream()
+ .findFirst()
+ .map(ExecutionFailureInfo::toException)
+ .orElseGet(() -> new PrestoException(GENERIC_INTERNAL_ERROR, "Native task failed for an unknown reason"));
+ throw failure;
+ }
+
+ log.info("processTaskInfoForErrors: task completed successfully = %s", taskInfo);
+ }
+
+ private void createAndStartNativeExecutionProcess(Session session)
+ {
+ requireNonNull(nativeExecutionProcessFactory, "Trying to instantiate native process but factory is null");
+
+ try {
+ // create the CPP sidecar process if it doesn't exist.
+ // We create this when the first task is scheduled
+ nativeExecutionProcess = nativeExecutionProcessFactory.getNativeExecutionProcess(session);
+ nativeExecutionProcess.start();
+ }
+ catch (ExecutionException | InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private List<TaskSource> getTaskSources(
+ Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources,
+ PlanFragment fragment,
+ Session session,
+ PrestoSparkNativeTaskInputs nativeTaskInputs)
+ {
+ List<TaskSource> taskSources = new ArrayList<>();
+
+ // Populate TableScan sources
+ long totalSerializedSizeInBytes = 0;
+ while (serializedTaskSources.hasNext()) {
+ SerializedPrestoSparkTaskSource serializedTaskSource = serializedTaskSources.next();
+ taskSources.add(deserializeZstdCompressed(taskSourceCodec, serializedTaskSource.getBytes()));
+ totalSerializedSizeInBytes += serializedTaskSource.getBytes().length;
+ }
+
+ // When joining bucketed table with a non-bucketed table with a filter on "$bucket",
+ // some tasks may not have splits for the bucketed table. In this case we still need
+ // to send no-more-splits message to Velox.
+ Set<PlanNodeId> planNodeIdsWithSources = taskSources.stream().map(TaskSource::getPlanNodeId).collect(Collectors.toSet());
+ Set<PlanNodeId> tableScanIds = Sets.newHashSet(scheduleOrder(fragment.getRoot()));
+ tableScanIds.stream()
+ .filter(id -> !planNodeIdsWithSources.contains(id))
+ .forEach(id -> taskSources.add(new TaskSource(id, ImmutableSet.of(), true)));
+
+ log.info("Total serialized size of all table scan task sources: %s", succinctBytes(totalSerializedSizeInBytes));
+
+ // Populate remote sources - ShuffleRead & Broadcast.
+ ImmutableList.Builder<TaskSource> shuffleTaskSources = ImmutableList.builder();
+ ImmutableList.Builder<TaskSource> broadcastTaskSources = ImmutableList.builder();
+ AtomicLong nextSplitId = new AtomicLong();
+ taskSources.stream()
+ .flatMap(source -> source.getSplits().stream())
+ .mapToLong(ScheduledSplit::getSequenceId)
+ .max()
+ .ifPresent(id -> nextSplitId.set(id + 1));
+
+ for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) {
+ for (PlanFragmentId sourceFragmentId : remoteSource.getSourceFragmentIds()) {
+ PrestoSparkShuffleReadDescriptor shuffleReadDescriptor =
+ nativeTaskInputs.getShuffleReadDescriptors().get(sourceFragmentId.toString());
+ if (shuffleReadDescriptor != null) {
+ ScheduledSplit split = new ScheduledSplit(nextSplitId.getAndIncrement(), remoteSource.getId(), new Split(REMOTE_CONNECTOR_ID, new RemoteTransactionHandle(), new RemoteSplit(
+ new Location(format("batch://%s?shuffleInfo=%s", DUMMY_TASK_ID,
+ shuffleInfoTranslator.createSerializedReadInfo(
+ shuffleInfoTranslator.createShuffleReadInfo(session, shuffleReadDescriptor)))),
+ DUMMY_TASK_ID)));
+ TaskSource source = new TaskSource(remoteSource.getId(), ImmutableSet.of(split), ImmutableSet.of(Lifespan.taskWide()), true);
+ shuffleTaskSources.add(source);
+ }
+
+ Broadcast<List<?>> broadcast = nativeTaskInputs.getBroadcastInputs().get(sourceFragmentId.toString());
+ if (broadcast != null) {
+ Set<ScheduledSplit> splits =
+ ((List>) broadcast.value()).stream()
+ .map(PrestoSparkSerializedPage.class::cast)
+ .map(prestoSparkSerializedPage -> PrestoSparkUtils.toSerializedPage(prestoSparkSerializedPage))
+ .map(serializedPage -> pagesSerde.deserialize(serializedPage))
+ // Extract filePath.
+ .flatMap(page -> IntStream.range(0, page.getPositionCount())
+ .mapToObj(position -> VarcharType.VARCHAR.getObjectValue(null, page.getBlock(0), position)))
+ .map(String.class::cast)
+ .map(filePath -> new BroadcastFileInfo(filePath))
+ .map(broadcastFileInfo -> new ScheduledSplit(
+ nextSplitId.getAndIncrement(),
+ remoteSource.getId(),
+ new Split(
+ REMOTE_CONNECTOR_ID,
+ new RemoteTransactionHandle(),
+ new RemoteSplit(
+ new Location(
+ format("batch://%s?broadcastInfo=%s", DUMMY_TASK_ID, broadcastFileInfoJsonCodec.toJson(broadcastFileInfo))),
+ DUMMY_TASK_ID))))
+ .collect(toImmutableSet());
+
+ TaskSource source = new TaskSource(remoteSource.getId(), splits, ImmutableSet.of(Lifespan.taskWide()), true);
+ broadcastTaskSources.add(source);
+ }
+ }
+ }
+
+ taskSources.addAll(shuffleTaskSources.build());
+ taskSources.addAll(broadcastTaskSources.build());
+ return taskSources;
+ }
+
+ private Optional<PlanNode> findTableWriteNode(PlanNode node)
+ {
+ return searchFrom(node)
+ .where(TableWriterNode.class::isInstance)
+ .findFirst();
+ }
+
+ private static class PrestoSparkNativeTaskOutputIterator<T extends PrestoSparkTaskOutput>
+ extends AbstractIterator<Tuple2<MutablePartitionId, T>>
+ implements IPrestoSparkTaskExecutor<T>
+ {
+ private final int partitionId;
+ private final NativeExecutionTask nativeExecutionTask;
+ private Optional<SerializedPage> next = Optional.empty();
+ private final CollectionAccumulator<SerializedTaskInfo> taskInfoCollectionAccumulator;
+ private final Codec<TaskInfo> taskInfoCodec;
+ private final Class<T> outputType;
+ private final PrestoSparkExecutionExceptionFactory executionExceptionFactory;
+ private final CpuTracker cpuTracker;
+ private final NativeExecutionProcess nativeExecutionProcess;
+ private final boolean terminateWithCoreWhenUnresponsive;
+ private final Duration terminateWithCoreTimeout;
+
+ public PrestoSparkNativeTaskOutputIterator(
+ int partitionId,
+ NativeExecutionTask nativeExecutionTask,
+ Class<T> outputType,
+ CollectionAccumulator<SerializedTaskInfo> taskInfoCollectionAccumulator,
+ Codec<TaskInfo> taskInfoCodec,
+ PrestoSparkExecutionExceptionFactory executionExceptionFactory,
+ CpuTracker cpuTracker,
+ NativeExecutionProcess nativeExecutionProcess,
+ boolean terminateWithCoreWhenUnresponsive,
+ Duration terminateWithCoreTimeout)
+ {
+ this.partitionId = partitionId;
+ this.nativeExecutionTask = nativeExecutionTask;
+ this.taskInfoCollectionAccumulator = taskInfoCollectionAccumulator;
+ this.taskInfoCodec = taskInfoCodec;
+ this.outputType = outputType;
+ this.executionExceptionFactory = executionExceptionFactory;
+ this.cpuTracker = cpuTracker;
+ this.nativeExecutionProcess = requireNonNull(nativeExecutionProcess, "nativeExecutionProcess is null");
+ this.terminateWithCoreWhenUnresponsive = terminateWithCoreWhenUnresponsive;
+ this.terminateWithCoreTimeout = requireNonNull(terminateWithCoreTimeout, "terminateWithCoreTimeout is null");
+ }
+
+ /**
+ * This function is called by Spark's RDD layer to check if there are output pages
+ * There are 2 scenarios
+ * 1. ShuffleMap Task - Always returns false. But the internal function calls do all the work needed
+ * 2. Result Task - True until pages are available. False once all pages have been extracted
+ *
+ * @return if output is available
+ */
+ @Override
+ public boolean hasNext()
+ {
+ next = computeNext();
+ return next.isPresent();
+ }
+
+ /**
+ * This function returns the next available page fetched from CPP process
+ *
+ * Has 3 main responsibilities
+ * 1) wait-for-pages-or-completion
+ *
+ * The thread running this method will wait until either of the 3 conditions happen
+ * * 1. We get a page
+ * * 2. Task has finished successfully
+ * * 3. Task has finished with error
+ *
+ * For ShuffleMap Task, as of now, the CPP process returns no pages.
+ * So the thread will be in WAITING state till the CPP task is done and returns an Optional.empty()
+ * once the task has terminated
+ *
+ * For a Result Task, this function will return pages retrieved from CPP side once we got them.
+ * Once all the pages have been read and the task has been terminates
+ *
+ * 2) Exception handling
+ * The function also checks if the task has finished
+ * with exceptions and throws the appropriate exception back to spark's RDD processing
+ * layer
+ *
+ * 3) Statistics collection
+ * For both, when the task finished successfully or with exception, it tries to collect
+ * statistics if TaskInfo object is available
+ *
+ * @return Optional outputPage
+ */
+ private Optional<SerializedPage> computeNext()
+ {
+ try {
+ Object taskFinishedOrHasResult = nativeExecutionTask.getTaskFinishedOrHasResult();
+ // Blocking wait if task is still running or hasn't produced any output page
+ synchronized (taskFinishedOrHasResult) {
+ while (!nativeExecutionTask.isTaskDone() && !nativeExecutionTask.hasResult()) {
+ taskFinishedOrHasResult.wait();
+ }
+ }
+
+ // For ShuffleMap Task, this will always return Optional.empty()
+ Optional<SerializedPage> pageOptional = nativeExecutionTask.pollResult();
+
+ if (pageOptional.isPresent()) {
+ return pageOptional;
+ }
+
+ // Double check if current task's already done (since thread could be awoken by either having output or task is done above)
+ synchronized (taskFinishedOrHasResult) {
+ while (!nativeExecutionTask.isTaskDone()) {
+ taskFinishedOrHasResult.wait();
+ }
+ }
+
+ Optional<TaskInfo> taskInfo = nativeExecutionTask.getTaskInfo();
+
+ processTaskInfoForErrorsOrCompletion(taskInfo.get());
+ }
+ catch (RuntimeException ex) {
+ // For a failed task, if taskInfo is present we still want to log the metrics
+ completeTask(false, taskInfoCollectionAccumulator, nativeExecutionTask, taskInfoCodec, cpuTracker);
+ throw executionExceptionFactory.toPrestoSparkExecutionException(processFailure(
+ ex,
+ nativeExecutionProcess,
+ terminateWithCoreWhenUnresponsive,
+ terminateWithCoreTimeout));
+ }
+ catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException(e);
+ }
+
+ // Reaching here marks the end of task processing
+ completeTask(true, taskInfoCollectionAccumulator, nativeExecutionTask, taskInfoCodec, cpuTracker);
+ return Optional.empty();
+ }
+
+ @Override
+ public Tuple2<MutablePartitionId, T> next()
+ {
+ // Result Tasks only have outputType of PrestoSparkSerializedPage.
+ checkArgument(outputType == PrestoSparkSerializedPage.class,
+ format("PrestoSparkNativeTaskExecutorFactory only outputType=PrestoSparkSerializedPage " +
+ "But tried to extract outputType=%s", outputType));
+
+ // Set partition ID to help match the results to the task on the driver for debugging.
+ MutablePartitionId mutablePartitionId = new MutablePartitionId();
+ mutablePartitionId.setPartition(partitionId);
+ return new Tuple2<>(mutablePartitionId, (T) toPrestoSparkSerializedPage(next.get()));
+ }
+ }
+
+ private static RuntimeException processFailure(
+ RuntimeException failure,
+ NativeExecutionProcess process,
+ boolean terminateWithCoreWhenUnresponsive,
+ Duration terminateWithCoreTimeout)
+ {
+ if (isCommunicationLoss(failure)) {
+ PrestoTransportException transportException = (PrestoTransportException) failure;
+ String message;
+ // lost communication with the native execution process
+ if (process.isAlive()) {
+ // process is unresponsive
+ if (terminateWithCoreWhenUnresponsive) {
+ process.terminateWithCore(terminateWithCoreTimeout);
+ }
+ message = "Native execution process is alive but unresponsive";
+ }
+ else {
+ message = "Native execution process is dead";
+ String crashReport = process.getCrashReport();
+ if (!crashReport.isEmpty()) {
+ message += ":\n" + crashReport;
+ }
+ }
+
+ return new PrestoTransportException(
+ transportException::getErrorCode,
+ transportException.getRemoteHost(),
+ message,
+ failure);
+ }
+ return failure;
+ }
+
+ private static boolean isCommunicationLoss(RuntimeException failure)
+ {
+ if (!(failure instanceof PrestoTransportException)) {
+ return false;
+ }
+ PrestoTransportException transportException = (PrestoTransportException) failure;
+ return TOO_MANY_REQUESTS_FAILED.toErrorCode().equals(transportException.getErrorCode());
+ }
+}
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java
index f06f859b9996a..7990d4ca83496 100644
--- a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java
@@ -23,6 +23,7 @@
import com.facebook.presto.spark.PrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.MutablePartitionId;
import com.facebook.presto.spark.classloader_interface.PrestoSparkMutableRow;
+import com.facebook.presto.spark.classloader_interface.PrestoSparkNativeTaskRdd;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput;
@@ -42,6 +43,7 @@
import com.facebook.presto.split.CloseableSplitSourceProvider;
import com.facebook.presto.split.SplitManager;
import com.facebook.presto.split.SplitSource;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.PartitioningProviderManager;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.SplitSourceFactory;
@@ -99,18 +101,21 @@ public class PrestoSparkRddFactory
private final PartitioningProviderManager partitioningProviderManager;
private final JsonCodec taskDescriptorJsonCodec;
private final Codec taskSourceCodec;
+ private final FeaturesConfig featuresConfig;
@Inject
public PrestoSparkRddFactory(
SplitManager splitManager,
PartitioningProviderManager partitioningProviderManager,
JsonCodec taskDescriptorJsonCodec,
- Codec taskSourceCodec)
+ Codec<TaskSource> taskSourceCodec,
+ FeaturesConfig featuresConfig)
{
this.splitManager = requireNonNull(splitManager, "splitManager is null");
this.partitioningProviderManager = requireNonNull(partitioningProviderManager, "partitioningProviderManager is null");
this.taskDescriptorJsonCodec = requireNonNull(taskDescriptorJsonCodec, "taskDescriptorJsonCodec is null");
this.taskSourceCodec = requireNonNull(taskSourceCodec, "taskSourceCodec is null");
+ this.featuresConfig = requireNonNull(featuresConfig, "featuresConfig is null");
}
public JavaPairRDD createSparkRdd(
@@ -250,14 +255,26 @@ else if (rddInputs.size() == 0) {
taskSourceRdd = Optional.empty();
}
- return JavaPairRDD.fromRDD(
- PrestoSparkTaskRdd.create(
- sparkContext.sc(),
- taskSourceRdd,
- shuffleInputRddMap,
- taskProcessor).setName(getRDDName(fragment.getId().getId())),
- classTag(MutablePartitionId.class),
- classTag(outputType));
+ if (featuresConfig.isNativeExecutionEnabled()) {
+ return JavaPairRDD.fromRDD(
+ PrestoSparkNativeTaskRdd.create(
+ sparkContext.sc(),
+ taskSourceRdd,
+ shuffleInputRddMap,
+ taskProcessor).setName(getRDDName(fragment.getId().getId())),
+ classTag(MutablePartitionId.class),
+ classTag(outputType));
+ }
+ else {
+ return JavaPairRDD.fromRDD(
+ PrestoSparkTaskRdd.create(
+ sparkContext.sc(),
+ taskSourceRdd,
+ shuffleInputRddMap,
+ taskProcessor).setName(getRDDName(fragment.getId().getId())),
+ classTag(MutablePartitionId.class),
+ classTag(outputType));
+ }
}
private PrestoSparkTaskSourceRdd createTaskSourcesRdd(
diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java
index 31d40b594a2c5..31b90de8859aa 100644
--- a/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java
+++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java
@@ -57,11 +57,13 @@
import com.facebook.presto.spark.classloader_interface.PrestoSparkSession;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider;
import com.facebook.presto.spark.execution.AbstractPrestoSparkQueryExecution;
+import com.facebook.presto.spark.execution.nativeprocess.NativeExecutionModule;
import com.facebook.presto.spi.NodeManager;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.eventlistener.EventListener;
import com.facebook.presto.spi.function.FunctionImplementationType;
+import com.facebook.presto.spi.security.Identity;
import com.facebook.presto.spi.security.PrincipalType;
import com.facebook.presto.split.PageSourceManager;
import com.facebook.presto.split.SplitManager;
@@ -243,7 +245,7 @@ public static PrestoSparkQueryRunner createHivePrestoSparkQueryRunner(Iterable SERVER_INFO_JSON_CODEC = JsonCodec.jsonCodec(ServerInfo.class);
+
+ @Test
+ public void testNativeProcessIsAlive()
+ {
+ Session session = testSessionBuilder().build();
+ NativeExecutionProcessFactory factory = createNativeExecutionProcessFactory();
+ NativeExecutionProcess process = factory.getNativeExecutionProcess(session);
+ // Simulate the process is closed (crashed)
+ process.close();
+ assertFalse(process.isAlive());
+ }
+
+ @Test
+ public void testNativeProcessRelaunch()
+ {
+ Session session = testSessionBuilder().build();
+ NativeExecutionProcessFactory factory = createNativeExecutionProcessFactory();
+ NativeExecutionProcess process = factory.getNativeExecutionProcess(session);
+ // Simulate the process is closed (crashed)
+ process.close();
+ assertFalse(process.isAlive());
+ NativeExecutionProcess process2 = factory.getNativeExecutionProcess(session);
+ // Expecting the factory re-created a new process object so that the process and process2
+ // should be two different objects
+ assertNotSame(process2, process);
+ }
+
+ @Test
+ public void testNativeProcessShutdown()
+ {
+ Session session = testSessionBuilder().build();
+ NativeExecutionProcessFactory factory = createNativeExecutionProcessFactory();
+ // Set the maxRetryDuration to 0 ms to allow the RequestErrorTracker failing immediately
+ NativeExecutionProcess process = factory.createNativeExecutionProcess(session, new Duration(0, TimeUnit.MILLISECONDS));
+ Throwable exception = expectThrows(PrestoSparkFatalException.class, process::start);
+ assertTrue(exception.getMessage().contains("Native process launch failed with multiple retries"));
+ assertFalse(process.isAlive());
+ }
+
+ private NativeExecutionProcessFactory createNativeExecutionProcessFactory()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+ ScheduledExecutorService errorScheduler = newScheduledThreadPool(4);
+ PrestoSparkWorkerProperty workerProperty = new PrestoSparkWorkerProperty(
+ new NativeExecutionConnectorConfig(),
+ new NativeExecutionNodeConfig(),
+ new NativeExecutionSystemConfig(),
+ new NativeExecutionVeloxConfig());
+ NativeExecutionProcessFactory factory = new NativeExecutionProcessFactory(
+ new TestPrestoSparkHttpClient.TestingHttpClient(
+ errorScheduler,
+ new TestPrestoSparkHttpClient.TestingResponseManager(taskId.toString(), new TestPrestoSparkHttpClient.FailureRetryResponseManager(5))),
+ newSingleThreadExecutor(),
+ errorScheduler,
+ SERVER_INFO_JSON_CODEC,
+ workerProperty,
+ new FeaturesConfig().setNativeExecutionExecutablePath("/bin/echo"));
+ return factory;
+ }
+}
diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/http/TestPrestoSparkHttpClient.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/http/TestPrestoSparkHttpClient.java
new file mode 100644
index 0000000000000..f3f516aede37d
--- /dev/null
+++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/http/TestPrestoSparkHttpClient.java
@@ -0,0 +1,1432 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.http;
+
+import com.facebook.airlift.http.client.HeaderName;
+import com.facebook.airlift.http.client.HttpClient;
+import com.facebook.airlift.http.client.HttpStatus;
+import com.facebook.airlift.http.client.Request;
+import com.facebook.airlift.http.client.RequestStats;
+import com.facebook.airlift.http.client.Response;
+import com.facebook.airlift.http.client.ResponseHandler;
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.presto.client.ServerInfo;
+import com.facebook.presto.execution.QueryManagerConfig;
+import com.facebook.presto.execution.TaskId;
+import com.facebook.presto.execution.TaskInfo;
+import com.facebook.presto.execution.TaskManagerConfig;
+import com.facebook.presto.execution.TaskSource;
+import com.facebook.presto.execution.TaskState;
+import com.facebook.presto.execution.TaskStatus;
+import com.facebook.presto.execution.scheduler.TableWriteInfo;
+import com.facebook.presto.operator.PageBufferClient;
+import com.facebook.presto.operator.PageTransportErrorException;
+import com.facebook.presto.operator.TaskStats;
+import com.facebook.presto.server.smile.BaseResponse;
+import com.facebook.presto.spark.execution.nativeprocess.HttpNativeExecutionTaskInfoFetcher;
+import com.facebook.presto.spark.execution.nativeprocess.HttpNativeExecutionTaskResultFetcher;
+import com.facebook.presto.spark.execution.nativeprocess.NativeExecutionProcess;
+import com.facebook.presto.spark.execution.nativeprocess.NativeExecutionProcessFactory;
+import com.facebook.presto.spark.execution.property.NativeExecutionConnectorConfig;
+import com.facebook.presto.spark.execution.property.NativeExecutionNodeConfig;
+import com.facebook.presto.spark.execution.property.NativeExecutionSystemConfig;
+import com.facebook.presto.spark.execution.property.NativeExecutionVeloxConfig;
+import com.facebook.presto.spark.execution.property.PrestoSparkWorkerProperty;
+import com.facebook.presto.spark.execution.task.NativeExecutionTask;
+import com.facebook.presto.spark.execution.task.NativeExecutionTaskFactory;
+import com.facebook.presto.spi.PrestoException;
+import com.facebook.presto.spi.PrestoTransportException;
+import com.facebook.presto.spi.page.PageCodecMarker;
+import com.facebook.presto.spi.page.PagesSerdeUtil;
+import com.facebook.presto.spi.page.SerializedPage;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import com.facebook.presto.sql.planner.PlanFragment;
+import com.facebook.presto.testing.TestingSession;
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.ListMultimap;
+import com.google.common.net.MediaType;
+import com.google.common.util.concurrent.AbstractFuture;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.SettableFuture;
+import io.airlift.slice.DynamicSliceOutput;
+import io.airlift.slice.Slice;
+import io.airlift.slice.Slices;
+import io.airlift.units.DataSize;
+import io.airlift.units.Duration;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.net.URI;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.regex.Pattern;
+
+import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilder;
+import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
+import static com.facebook.presto.PrestoMediaTypes.PRESTO_PAGES_TYPE;
+import static com.facebook.presto.client.NodeVersion.UNKNOWN;
+import static com.facebook.presto.client.PrestoHeaders.PRESTO_BUFFER_COMPLETE;
+import static com.facebook.presto.client.PrestoHeaders.PRESTO_PAGE_NEXT_TOKEN;
+import static com.facebook.presto.client.PrestoHeaders.PRESTO_PAGE_TOKEN;
+import static com.facebook.presto.client.PrestoHeaders.PRESTO_TASK_INSTANCE_ID;
+import static com.facebook.presto.execution.TaskTestUtils.createPlanFragment;
+import static com.facebook.presto.execution.buffer.OutputBuffers.BufferType.PARTITIONED;
+import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers;
+import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
+import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
+import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
+import static io.airlift.units.DataSize.Unit.MEGABYTE;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+import static java.util.concurrent.Executors.newScheduledThreadPool;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.expectThrows;
+import static org.testng.Assert.fail;
+
+public class TestPrestoSparkHttpClient
+{
+ private static final String TASK_ROOT_PATH = "/v1/task";
+ private static final URI BASE_URI = uriBuilder()
+ .scheme("http")
+ .host("localhost")
+ .port(8080)
+ .build();
+ private static final Duration NO_DURATION = new Duration(0, TimeUnit.MILLISECONDS);
+ private static final JsonCodec<TaskInfo> TASK_INFO_JSON_CODEC = JsonCodec.jsonCodec(TaskInfo.class);
+ private static final JsonCodec<PlanFragment> PLAN_FRAGMENT_JSON_CODEC = JsonCodec.jsonCodec(PlanFragment.class);
+ private static final JsonCodec<BatchTaskUpdateRequest> TASK_UPDATE_REQUEST_JSON_CODEC = JsonCodec.jsonCodec(BatchTaskUpdateRequest.class);
+ private static final JsonCodec<ServerInfo> SERVER_INFO_JSON_CODEC = JsonCodec.jsonCodec(ServerInfo.class);
+
+ private ScheduledExecutorService scheduledExecutorService;
+
+ @BeforeClass
+ public void beforeClass()
+ {
+ scheduledExecutorService = newScheduledThreadPool(4);
+ }
+
+ @AfterClass(alwaysRun = true)
+ public void afterClass()
+ {
+ if (scheduledExecutorService != null) {
+ scheduledExecutorService.shutdownNow();
+ scheduledExecutorService = null;
+ }
+ }
+
+ @Test
+ public void testResultGet()
+ {
+ TaskId taskId = new TaskId(
+ "testid",
+ 0,
+ 0,
+ 0,
+ 0);
+
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);
+ ListenableFuture<PageBufferClient.PagesResponse> future = workerClient.getResults(
+ 0,
+ new DataSize(32, MEGABYTE));
+ try {
+ PageBufferClient.PagesResponse page = future.get();
+ assertEquals(0, page.getToken());
+ assertTrue(page.isClientComplete());
+ assertEquals(taskId.toString(), page.getTaskInstanceId());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+
+ @Test
+ public void testResultAcknowledge()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);
+ workerClient.acknowledgeResultsAsync(1);
+ }
+
+ private PrestoSparkHttpTaskClient createWorkerClient(TaskId taskId)
+ {
+ return createWorkerClient(taskId, new TestingHttpClient(scheduledExecutorService, new TestingResponseManager(taskId.toString())));
+ }
+
+ private PrestoSparkHttpTaskClient createWorkerClient(TaskId taskId, TestingHttpClient httpClient)
+ {
+ return new PrestoSparkHttpTaskClient(
+ httpClient,
+ taskId,
+ BASE_URI,
+ TASK_INFO_JSON_CODEC,
+ PLAN_FRAGMENT_JSON_CODEC,
+ TASK_UPDATE_REQUEST_JSON_CODEC,
+ new Duration(1, TimeUnit.SECONDS),
+ scheduledExecutorService,
+ scheduledExecutorService,
+ new Duration(1, TimeUnit.SECONDS));
+ }
+
+ HttpNativeExecutionTaskResultFetcher createResultFetcher(PrestoSparkHttpTaskClient workerClient)
+ {
+ return createResultFetcher(workerClient, new Object());
+ }
+
+ HttpNativeExecutionTaskResultFetcher createResultFetcher(PrestoSparkHttpTaskClient workerClient, Object lock)
+ {
+ return new HttpNativeExecutionTaskResultFetcher(
+ scheduledExecutorService,
+ workerClient,
+ lock);
+ }
+
+ @Test
+ public void testResultAbort()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);
+ ListenableFuture> future = workerClient.abortResultsAsync();
+ try {
+ future.get();
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+
+ @Test
+ public void testGetTaskInfo()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);
+ try {
+ TaskInfo taskInfo = workerClient.getTaskInfo();
+ assertEquals(taskInfo.getTaskId().toString(), taskId.toString());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+
+ @Test
+ public void testUpdateTask()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);
+
+ List<TaskSource> sources = new ArrayList<>();
+
+ try {
+ TaskInfo taskInfo = workerClient.updateTask(
+ sources,
+ createPlanFragment(),
+ new TableWriteInfo(Optional.empty(), Optional.empty()),
+ Optional.empty(),
+ Optional.empty(),
+ TestingSession.testSessionBuilder().build(),
+ createInitialEmptyOutputBuffers(PARTITIONED));
+ assertEquals(taskInfo.getTaskId().toString(), taskId.toString());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+
+ @Test
+ public void testUpdateTaskUnexpectedResponse()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(
+ taskId,
+ new TestingHttpClient(scheduledExecutorService, new TestingResponseManager(taskId.toString(), new UnexpectedResponseTaskInfoRetryResponseManager())));
+ assertThatThrownBy(() -> workerClient.updateTask(
+ new ArrayList<>(),
+ createPlanFragment(),
+ new TableWriteInfo(Optional.empty(), Optional.empty()),
+ Optional.empty(),
+ Optional.empty(),
+ TestingSession.testSessionBuilder().build(),
+ createInitialEmptyOutputBuffers(PARTITIONED)))
+ .isInstanceOf(PrestoException.class)
+ .hasMessageContaining("500");
+ }
+
+ @Test
+ public void testUpdateTaskWithRetries()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(
+ taskId,
+ new TestingHttpClient(scheduledExecutorService, new TestingResponseManager(taskId.toString(), new FailureRetryTaskInfoResponseManager(2))));
+ workerClient.updateTask(
+ new ArrayList<>(),
+ createPlanFragment(),
+ new TableWriteInfo(Optional.empty(), Optional.empty()),
+ Optional.empty(),
+ Optional.empty(),
+ TestingSession.testSessionBuilder().build(),
+ createInitialEmptyOutputBuffers(PARTITIONED));
+ }
+
+ @Test
+ public void testGetServerInfo()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+ ServerInfo expected = new ServerInfo(UNKNOWN, "test", true, false, Optional.of(Duration.valueOf("2m")));
+
+ PrestoSparkHttpServerClient workerClient = new PrestoSparkHttpServerClient(
+ new TestingHttpClient(scheduledExecutorService, new TestingResponseManager(taskId.toString())),
+ BASE_URI,
+ SERVER_INFO_JSON_CODEC);
+ ListenableFuture<BaseResponse<ServerInfo>> future = workerClient.getServerInfo();
+ try {
+ ServerInfo serverInfo = future.get().getValue();
+ assertEquals(serverInfo, expected);
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+
+ @Test
+ public void testGetServerInfoWithRetry()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+ ServerInfo expected = new ServerInfo(UNKNOWN, "test", true, false, Optional.of(Duration.valueOf("2m")));
+ Duration maxTimeout = new Duration(1, TimeUnit.MINUTES);
+ NativeExecutionProcess process = createNativeExecutionProcess(
+ maxTimeout,
+ new TestingResponseManager(taskId.toString(), new FailureRetryResponseManager(5)));
+
+ SettableFuture<ServerInfo> future = process.getServerInfoWithRetry();
+ try {
+ ServerInfo serverInfo = future.get();
+ assertEquals(serverInfo, expected);
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+
+ // With a zero error budget the retry loop must give up immediately and
+ // surface the launch failure through the returned future.
+ @Test
+ public void testGetServerInfoWithRetryTimeout()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+ Duration maxTimeout = new Duration(0, TimeUnit.MILLISECONDS);
+ NativeExecutionProcess process = createNativeExecutionProcess(
+ maxTimeout,
+ new TestingResponseManager(taskId.toString(), new FailureRetryResponseManager(5)));
+
+ SettableFuture future = process.getServerInfoWithRetry();
+ Exception exception = expectThrows(ExecutionException.class, future::get);
+ assertTrue(exception.getMessage().contains("Native process launch failed with multiple retries"));
+ }
+
+ // Basic result-fetcher flow: the default response manager serves exactly
+ // one empty page with bufferComplete=true, so polling drains a single
+ // zero-byte page and then returns empty.
+ @Test
+ public void testResultFetcher()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);
+ HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient);
+ taskResultFetcher.start();
+ try {
+ List pages = new ArrayList<>();
+ Optional page = taskResultFetcher.pollPage();
+ while (page.isPresent()) {
+ pages.add(page.get());
+ page = taskResultFetcher.pollPage();
+ }
+
+ assertEquals(1, pages.size());
+ assertEquals(0, pages.get(0).getSizeInBytes());
+ }
+ catch (InterruptedException e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+
+ // Polls the fetcher until numPages pages are collected, giving up after
+ // 1,000 polls so a stalled fetcher cannot hang the test forever.
+ private List fetchResults(HttpNativeExecutionTaskResultFetcher taskResultFetcher, int numPages)
+ throws InterruptedException
+ {
+ List pages = new ArrayList<>();
+ for (int i = 0; i < 1_000 && pages.size() < numPages; ++i) {
+ Optional page = taskResultFetcher.pollPage();
+ if (page.isPresent()) {
+ pages.add(page.get());
+ }
+ }
+ return pages;
+ }
+
+ // Streams 10 non-empty 1MB pages; the last response sets
+ // bufferComplete=true, and any request after completion fails the test.
+ @Test
+ public void testResultFetcherMultipleNonEmptyResults()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+ int serializedPageSize = (int) new DataSize(1, MEGABYTE).toBytes();
+ int numPages = 10;
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(
+ taskId,
+ new TestingHttpClient(
+ scheduledExecutorService,
+ new TestingResponseManager(taskId.toString(), new TestingResponseManager.TestingResultResponseManager()
+ {
+ // Number of result requests served so far; drives token values.
+ private int requestCount;
+
+ @Override
+ public Response createResultResponse(String taskId)
+ throws PageTransportErrorException
+ {
+ requestCount++;
+ if (requestCount < numPages) {
+ // Intermediate page: buffer not yet complete.
+ return createResultResponseHelper(
+ HttpStatus.OK,
+ taskId,
+ requestCount - 1,
+ requestCount,
+ false,
+ serializedPageSize);
+ }
+ else if (requestCount == numPages) {
+ // Final page: signal buffer completion.
+ return createResultResponseHelper(
+ HttpStatus.OK,
+ taskId,
+ requestCount - 1,
+ requestCount,
+ true,
+ serializedPageSize);
+ }
+ else {
+ fail("Retrieving results after buffer completion");
+ return null;
+ }
+ }
+ })));
+ HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient);
+ taskResultFetcher.start();
+ try {
+ List pages = fetchResults(taskResultFetcher, numPages);
+
+ assertEquals(numPages, pages.size());
+ for (int i = 0; i < numPages; i++) {
+ assertEquals(pages.get(i).getSizeInBytes(), serializedPageSize);
+ }
+ }
+ catch (InterruptedException e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+
+ // Result response manager that serves numPages fixed-size pages and exposes
+ // how many remain, letting tests observe when fetching pauses (e.g. when
+ // the fetcher's in-memory buffer limit is reached).
+ private static class BreakingLimitResponseManager
+ extends TestingResponseManager.TestingResultResponseManager
+ {
+ private final int serializedPageSize;
+ private final int numPages;
+
+ // Requests served so far; also used as the page token.
+ private int requestCount;
+
+ public BreakingLimitResponseManager(int serializedPageSize, int numPages)
+ {
+ this.serializedPageSize = serializedPageSize;
+ this.numPages = numPages;
+ }
+
+ @Override
+ public Response createResultResponse(String taskId)
+ throws PageTransportErrorException
+ {
+ requestCount++;
+ if (requestCount < numPages) {
+ return createResultResponseHelper(
+ HttpStatus.OK,
+ taskId,
+ requestCount - 1,
+ requestCount,
+ false,
+ serializedPageSize);
+ }
+ else if (requestCount == numPages) {
+ // Last page flips bufferComplete to true.
+ return createResultResponseHelper(
+ HttpStatus.OK,
+ taskId,
+ requestCount - 1,
+ requestCount,
+ true,
+ serializedPageSize);
+ }
+ else {
+ fail("Retrieving results after buffer completion");
+ return null;
+ }
+ }
+
+ // Pages not yet requested by the fetcher.
+ public int getRemainingPageCount()
+ {
+ return numPages - requestCount;
+ }
+ }
+
+ // Serves 10 pages of 32MB each so total volume exceeds the fetcher's
+ // buffer limit; expects fetching to stall (5 pages left unfetched) until
+ // the consumer drains pages, then all 10 pages to arrive intact.
+ // NOTE(review): relies on a fixed 5s sleep and an exact remaining-page
+ // count — presumably tied to a 160MB-ish buffer cap; timing-sensitive and
+ // potentially flaky on slow machines. Confirm the limit before changing.
+ @Test
+ public void testResultFetcherExceedingBufferLimit()
+ {
+ int numPages = 10;
+ int serializedPageSize = (int) new DataSize(32, MEGABYTE).toBytes();
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+
+ BreakingLimitResponseManager breakingLimitResponseManager =
+ new BreakingLimitResponseManager(serializedPageSize, numPages);
+
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(
+ taskId,
+ new TestingHttpClient(
+ scheduledExecutorService,
+ new TestingResponseManager(
+ taskId.toString(),
+ breakingLimitResponseManager)));
+ HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient);
+ taskResultFetcher.start();
+ try {
+ Optional page = Optional.empty();
+ while (!page.isPresent()) {
+ page = taskResultFetcher.pollPage();
+ }
+ // Wait a bit for fetches to overwhelm memory.
+ Thread.sleep(5000);
+ assertEquals(breakingLimitResponseManager.getRemainingPageCount(), 5);
+ List pages = new ArrayList<>();
+ pages.add(page.get());
+ while (pages.size() < numPages) {
+ page = taskResultFetcher.pollPage();
+ page.ifPresent(pages::add);
+ }
+
+ assertEquals(numPages, pages.size());
+ for (int i = 0; i < numPages; i++) {
+ assertEquals(pages.get(i).getSizeInBytes(), serializedPageSize);
+ }
+ }
+ catch (InterruptedException e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+
+ // Result response manager that throws for the first numInitialTimeouts
+ // requests (simulating transport errors), then serves numPages pages
+ // normally, ending with bufferComplete=true.
+ private static class TimeoutResponseManager
+ extends TestingResponseManager.TestingResultResponseManager
+ {
+ private final int serializedPageSize;
+ private final int numPages;
+ private final int numInitialTimeouts;
+
+ // Successful requests served so far (used as page token).
+ private int requestCount;
+ // Total requests seen, including the initial failing ones.
+ private int timeoutCount;
+
+ public TimeoutResponseManager(int serializedPageSize, int numPages, int numInitialTimeouts)
+ {
+ this.serializedPageSize = serializedPageSize;
+ this.numPages = numPages;
+ this.numInitialTimeouts = numInitialTimeouts;
+ }
+
+ @Override
+ public Response createResultResponse(String taskId)
+ throws PageTransportErrorException
+ {
+ // Fail the first numInitialTimeouts requests to exercise retry logic.
+ if (++timeoutCount <= numInitialTimeouts) {
+ throw new RuntimeException("test failure");
+ }
+ requestCount++;
+ if (requestCount < numPages) {
+ return createResultResponseHelper(
+ HttpStatus.OK,
+ taskId,
+ requestCount - 1,
+ requestCount,
+ false,
+ serializedPageSize);
+ }
+ else if (requestCount == numPages) {
+ return createResultResponseHelper(
+ HttpStatus.OK,
+ taskId,
+ requestCount - 1,
+ requestCount,
+ true,
+ serializedPageSize);
+ }
+ else {
+ fail("Retrieving results after buffer completion");
+ return null;
+ }
+ }
+ }
+
+ // Result response manager that throws a non-retriable PrestoException on
+ // the first request; any further request indicates the client wrongly
+ // retried a non-retriable failure.
+ private static class PrestoExceptionResponseManager
+ extends TestingResponseManager.TestingResultResponseManager
+ {
+ private int requestCount;
+
+ @Override
+ public Response createResultResponse(String taskId)
+ throws PageTransportErrorException
+ {
+ if (requestCount == 0) {
+ requestCount++;
+ throw new PrestoException(GENERIC_INTERNAL_ERROR, "non retriable failure");
+ }
+ throw new RuntimeException("expected to be called only once");
+ }
+ }
+
+ // Recovery path: 3 initial transport errors (below the retry cap) must be
+ // absorbed, after which all 10 (empty) pages are fetched successfully.
+ @Test
+ public void testResultFetcherTransportErrorRecovery()
+ {
+ int numPages = 10;
+ int serializedPageSize = 0;
+ // Transport error count less than MAX_TRANSPORT_ERROR_RETRIES (5).
+ // Expecting recovery from failed requests
+ int numTransportErrors = 3;
+
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+
+ TimeoutResponseManager timeoutResponseManager =
+ new TimeoutResponseManager(serializedPageSize, numPages, numTransportErrors);
+
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(
+ taskId,
+ new TestingHttpClient(
+ scheduledExecutorService,
+ new TestingResponseManager(
+ taskId.toString(),
+ timeoutResponseManager)));
+ HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient);
+ taskResultFetcher.start();
+ try {
+ List pages = fetchResults(taskResultFetcher, numPages);
+
+ assertEquals(pages.size(), numPages);
+ for (int i = 0; i < numPages; i++) {
+ assertEquals(pages.get(i).getSizeInBytes(), serializedPageSize);
+ }
+ }
+ catch (InterruptedException e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+
+ // Failure path: 10 consecutive transport errors exceed the retry budget,
+ // so polling must eventually surface a PrestoTransportException.
+ @Test
+ public void testResultFetcherTransportErrorFail()
+ throws InterruptedException
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(
+ taskId,
+ new TestingHttpClient(
+ scheduledExecutorService,
+ new TestingResponseManager(taskId.toString(), new TimeoutResponseManager(0, 10, 10))));
+ HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient);
+ taskResultFetcher.start();
+ try {
+ for (int i = 0; i < 1_000; ++i) {
+ taskResultFetcher.pollPage();
+ }
+ fail("Expected an exception");
+ }
+ catch (PrestoTransportException e) {
+ assertTrue(e.getMessage().startsWith("getResults encountered too many errors talking to native process"));
+ }
+ }
+
+ // A non-retriable PrestoException from the worker must be rethrown by
+ // pollPage, not swallowed or retried. The monitor is used to wait until
+ // the fetcher signals that a page (or failure) is available.
+ @Test
+ public void testResultFetcherPrestoException()
+ throws InterruptedException
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(
+ taskId,
+ new TestingHttpClient(
+ scheduledExecutorService,
+ new TestingResponseManager(taskId.toString(), new PrestoExceptionResponseManager())));
+ Object monitor = new Object();
+ HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient, monitor);
+ taskResultFetcher.start();
+ synchronized (monitor) {
+ try {
+ while (!taskResultFetcher.hasPage()) {
+ monitor.wait();
+ }
+ }
+ catch (RuntimeException ignored) {
+ }
+ }
+ assertThatThrownBy(taskResultFetcher::pollPage)
+ .isInstanceOf(PrestoException.class)
+ .hasMessage("non retriable failure");
+ }
+
+ // Verifies the fetcher notifies the shared lock when a page arrives, so a
+ // consumer blocked in wait() wakes up and observes hasPage() == true.
+ @Test
+ public void testResultFetcherWaitOnSignal()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+ Object lock = new Object();
+
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId);
+ HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient, lock);
+ taskResultFetcher.start();
+ try {
+ synchronized (lock) {
+ while (!taskResultFetcher.hasPage()) {
+ lock.wait();
+ }
+ }
+ assertTrue(taskResultFetcher.hasPage());
+ }
+ catch (InterruptedException e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+
+ // The info fetcher starts with no TaskInfo; after sleeping ~3 fetch
+ // intervals it must have polled the worker and cached a TaskInfo.
+ // NOTE(review): fetchInterval here is only used to size the sleep — the
+ // fetcher's real interval comes from createTaskInfoFetcher.
+ @Test
+ public void testInfoFetcher()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+
+ Duration fetchInterval = new Duration(1, TimeUnit.SECONDS);
+ HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(taskId, new TestingResponseManager(taskId.toString()));
+ assertFalse(taskInfoFetcher.getTaskInfo().isPresent());
+ taskInfoFetcher.start();
+ try {
+ Thread.sleep(3 * fetchInterval.toMillis());
+ }
+ catch (InterruptedException e) {
+ e.printStackTrace();
+ fail();
+ }
+ assertTrue(taskInfoFetcher.getTaskInfo().isPresent());
+ }
+
+ // FailureTaskInfoRetryResponseManager(1) succeeds on early requests and
+ // then starts failing; the first fetch must succeed, and once the 5s error
+ // budget is exhausted getTaskInfo must throw the too-many-errors failure.
+ @Test
+ public void testInfoFetcherWithRetry()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+
+ Duration fetchInterval = new Duration(1, TimeUnit.SECONDS);
+ HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(
+ taskId,
+ new TestingResponseManager(taskId.toString(), new FailureTaskInfoRetryResponseManager(1)),
+ new Duration(5, TimeUnit.SECONDS),
+ new Object());
+ assertFalse(taskInfoFetcher.getTaskInfo().isPresent());
+ taskInfoFetcher.start();
+ try {
+ Thread.sleep(3 * fetchInterval.toMillis());
+ }
+ catch (InterruptedException e) {
+ e.printStackTrace();
+ fail();
+ }
+
+ // First fetch is expected to succeed.
+ assertTrue(taskInfoFetcher.getTaskInfo().isPresent());
+
+ try {
+ Thread.sleep(10 * fetchInterval.toMillis());
+ }
+ catch (InterruptedException e) {
+ e.printStackTrace();
+ fail();
+ }
+ Exception exception = expectThrows(RuntimeException.class, taskInfoFetcher::getTaskInfo);
+ assertThat(exception.getMessage())
+ .contains("getTaskInfo encountered too many errors talking to native process");
+ }
+
+ // A 500 response to the info request must be recorded as the fetcher's
+ // last exception and rethrown (as a PrestoException mentioning "500") on
+ // the next getTaskInfo call. Bounded by a 60s TestNG timeout.
+ @Test(timeOut = 60 * 1000)
+ public void testInfoFetcherUnexpectedResponse()
+ throws InterruptedException
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+ Object monitor = new Object();
+ HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(
+ taskId,
+ new TestingResponseManager(taskId.toString(), new UnexpectedResponseTaskInfoRetryResponseManager()),
+ new Duration(5, TimeUnit.SECONDS),
+ monitor);
+ taskInfoFetcher.start();
+ synchronized (monitor) {
+ while (taskInfoFetcher.getLastException().get() == null && !taskInfoFetcher.getTaskInfo().isPresent()) {
+ monitor.wait();
+ }
+ }
+ assertThatThrownBy(taskInfoFetcher::getTaskInfo)
+ .isInstanceOf(PrestoException.class)
+ .hasMessageContaining("500");
+ }
+
+ // Verifies the info fetcher notifies the shared lock when the task reaches
+ // a terminal state (responses report FINISHED), waking a blocked waiter.
+ @Test
+ public void testInfoFetcherWaitOnSignal()
+ {
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+ Object lock = new Object();
+
+ HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(taskId, new TestingResponseManager(taskId.toString(), TaskState.FINISHED), lock);
+ assertFalse(taskInfoFetcher.getTaskInfo().isPresent());
+ taskInfoFetcher.start();
+ try {
+ synchronized (lock) {
+ while (!isTaskDone(taskInfoFetcher.getTaskInfo())) {
+ lock.wait();
+ }
+ }
+ }
+ catch (InterruptedException e) {
+ fail();
+ }
+ assertTrue(isTaskDone(taskInfoFetcher.getTaskInfo()));
+ }
+
+ // True iff a TaskInfo is present and its status reports a terminal state.
+ private boolean isTaskDone(Optional taskInfo)
+ {
+ return taskInfo.isPresent() && taskInfo.get().getTaskStatus().getState().isDone();
+ }
+
+ // End-to-end NativeExecutionTask flow: create a task via the factory,
+ // start it, poll all 10 (empty) result pages, confirm TaskInfo is
+ // available, then stop the task.
+ @Test
+ public void testNativeExecutionTask()
+ {
+ // We need multi-thread scheduler to increase scheduling concurrency.
+ // Otherwise, async execution assumption is not going to hold with a
+ // single thread.
+ TaskId taskId = new TaskId("testid", 0, 0, 0, 0);
+ TaskManagerConfig taskConfig = new TaskManagerConfig();
+ QueryManagerConfig queryConfig = new QueryManagerConfig();
+ taskConfig.setInfoRefreshMaxWait(new Duration(5, TimeUnit.SECONDS));
+ taskConfig.setInfoUpdateInterval(new Duration(200, TimeUnit.MILLISECONDS));
+ queryConfig.setRemoteTaskMaxErrorDuration(new Duration(1, TimeUnit.MINUTES));
+ List sources = new ArrayList<>();
+ try {
+ NativeExecutionTaskFactory taskFactory = new NativeExecutionTaskFactory(
+ new TestingHttpClient(
+ scheduledExecutorService,
+ new TestingResponseManager(taskId.toString(), new TimeoutResponseManager(0, 10, 0))),
+ scheduledExecutorService,
+ scheduledExecutorService,
+ TASK_INFO_JSON_CODEC,
+ PLAN_FRAGMENT_JSON_CODEC,
+ TASK_UPDATE_REQUEST_JSON_CODEC,
+ taskConfig,
+ queryConfig);
+ NativeExecutionTask task = taskFactory.createNativeExecutionTask(
+ testSessionBuilder().build(),
+ BASE_URI,
+ taskId,
+ createPlanFragment(),
+ sources,
+ new TableWriteInfo(Optional.empty(), Optional.empty()),
+ Optional.empty(),
+ Optional.empty());
+ assertNotNull(task);
+ // Before start(): no info and no results yet.
+ assertFalse(task.getTaskInfo().isPresent());
+ assertFalse(task.pollResult().isPresent());
+
+ // Start task
+ TaskInfo taskInfo = task.start();
+ assertFalse(taskInfo.getTaskStatus().getState().isDone());
+
+ List resultPages = new ArrayList<>();
+ for (int i = 0; i < 100 && resultPages.size() < 10; ++i) {
+ Optional page = task.pollResult();
+ page.ifPresent(resultPages::add);
+ }
+ // Buffer is complete after 10 pages; nothing further should arrive.
+ assertFalse(task.pollResult().isPresent());
+ assertEquals(10, resultPages.size());
+ assertTrue(task.getTaskInfo().isPresent());
+
+ task.stop(true);
+ }
+ catch (InterruptedException e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+
+ // Builds a NativeExecutionProcess wired to the given canned-response
+ // manager, with default worker properties and the supplied error budget.
+ private NativeExecutionProcess createNativeExecutionProcess(
+ Duration maxErrorDuration,
+ TestingResponseManager responseManager)
+ {
+ PrestoSparkWorkerProperty workerProperty = new PrestoSparkWorkerProperty(
+ new NativeExecutionConnectorConfig(),
+ new NativeExecutionNodeConfig(),
+ new NativeExecutionSystemConfig(),
+ new NativeExecutionVeloxConfig());
+ NativeExecutionProcessFactory factory = new NativeExecutionProcessFactory(
+ new TestingHttpClient(scheduledExecutorService, responseManager),
+ scheduledExecutorService,
+ scheduledExecutorService,
+ SERVER_INFO_JSON_CODEC,
+ workerProperty,
+ new FeaturesConfig());
+ return factory.createNativeExecutionProcess(testSessionBuilder().build(), maxErrorDuration);
+ }
+
+ // Convenience overload: 1-minute error budget, private lock object.
+ private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager)
+ {
+ return createTaskInfoFetcher(taskId, testingResponseManager, new Duration(1, TimeUnit.MINUTES), new Object());
+ }
+
+ // Convenience overload: 1-minute error budget, caller-supplied lock so
+ // tests can wait for fetcher notifications.
+ private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager, Object lock)
+ {
+ return createTaskInfoFetcher(taskId, testingResponseManager, new Duration(1, TimeUnit.MINUTES), lock);
+ }
+
+ // Builds an info fetcher polling every 1 second against a testing client.
+ // NOTE(review): the maxErrorDuration parameter is accepted but never used
+ // in this body — callers pass 5s/1min values that have no effect. It looks
+ // like it should be forwarded to the fetcher's constructor; confirm the
+ // HttpNativeExecutionTaskInfoFetcher signature and wire it through.
+ private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager, Duration maxErrorDuration, Object lock)
+ {
+ PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId, new TestingHttpClient(scheduledExecutorService, testingResponseManager));
+ return new HttpNativeExecutionTaskInfoFetcher(
+ scheduledExecutorService,
+ workerClient,
+ new Duration(1, TimeUnit.SECONDS),
+ lock);
+ }
+
+ // Minimal HttpResponseFuture backed by Guava's AbstractFuture, exposing
+ // complete/completeExceptionally so the testing client can settle it.
+ private static class TestingHttpResponseFuture
+ extends AbstractFuture
+ implements HttpClient.HttpResponseFuture
+ {
+ @Override
+ public String getState()
+ {
+ // State reporting is not needed by these tests.
+ return null;
+ }
+
+ public void complete(T value)
+ {
+ super.set(value);
+ }
+
+ public void completeExceptionally(Throwable t)
+ {
+ super.setException(t);
+ }
+ }
+
+ // In-memory HttpClient that routes requests by method + URI-path regex to
+ // the TestingResponseManager, executing asynchronously on the given
+ // scheduler. Unmatched requests complete the future exceptionally.
+ public static class TestingHttpClient
+ implements com.facebook.airlift.http.client.HttpClient
+ {
+ private static final String TASK_ID_REGEX = "/v1/task/[a-zA-Z0-9]+.[0-9]+.[0-9]+.[0-9]+.[0-9]+";
+ private final ScheduledExecutorService executor;
+ private final TestingResponseManager responseManager;
+
+ public TestingHttpClient(ScheduledExecutorService executor, TestingResponseManager responseManager)
+ {
+ this.executor = executor;
+ this.responseManager = responseManager;
+ }
+
+ @Override
+ public T execute(Request request, ResponseHandler responseHandler)
+ throws E
+ {
+ // Synchronous wrapper over executeAsync; failures are logged and
+ // reported as null rather than rethrown.
+ try {
+ return executeAsync(request, responseHandler).get();
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ return null;
+ }
+ }
+
+ @Override
+ public HttpResponseFuture executeAsync(Request request, ResponseHandler responseHandler)
+ {
+ TestingHttpResponseFuture future = new TestingHttpResponseFuture<>();
+ executor.schedule(
+ () ->
+ {
+ URI uri = request.getUri();
+ String method = request.getMethod();
+ String path = uri.getPath();
+ try {
+ if (method.equalsIgnoreCase("GET")) {
+ // GET /v1/task/{taskId}
+ if (Pattern.compile(TASK_ID_REGEX + "\\z").matcher(path).find()) {
+ future.complete(responseHandler.handle(request, responseManager.createTaskInfoResponse(HttpStatus.OK)));
+ }
+ // GET /v1/task/{taskId}/results/{bufferId}/{token}/acknowledge
+ else if (Pattern.compile(".*/results/[0-9]+/[0-9]+/acknowledge\\z").matcher(path).find()) {
+ future.complete(responseHandler.handle(request, responseManager.createDummyResultResponse()));
+ }
+ // GET /v1/task/{taskId}/results/{bufferId}/{token}
+ else if (Pattern.compile(".*/results/[0-9]+/[0-9]+\\z").matcher(path).find()) {
+ future.complete(responseHandler.handle(
+ request,
+ responseManager.createResultResponse()));
+ }
+ // GET /v1/info
+ else if (Pattern.compile("/v1/info").matcher(path).find()) {
+ future.complete(responseHandler.handle(
+ request,
+ responseManager.createServerInfoResponse()));
+ }
+ }
+ else if (method.equalsIgnoreCase("POST")) {
+ // POST /v1/task/{taskId}/batch
+ if (Pattern.compile(format("%s\\/batch\\z", TASK_ID_REGEX)).matcher(path).find()) {
+ future.complete(responseHandler.handle(request, responseManager.createTaskInfoResponse(HttpStatus.OK)));
+ }
+ }
+ else if (method.equalsIgnoreCase("DELETE")) {
+ // DELETE /v1/task/{taskId}/results/{bufferId}
+ if (Pattern.compile(format("%s\\/results\\/[0-9]+\\z", TASK_ID_REGEX)).matcher(path).find()) {
+ future.complete(responseHandler.handle(request, responseManager.createDummyResultResponse()));
+ }
+ // DELETE /v1/task/{taskId}
+ else if (Pattern.compile(TASK_ID_REGEX + "\\z").matcher(path).find()) {
+ future.complete(responseHandler.handle(request, responseManager.createDummyResultResponse()));
+ }
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ future.completeExceptionally(e);
+ }
+
+ // No route matched: fail the future so callers see the bad request.
+ if (!future.isDone()) {
+ future.completeExceptionally(new Exception(format("Unsupported request: %s %s", method, path)));
+ }
+ },
+ (long) NO_DURATION.getValue(),
+ NO_DURATION.getUnit());
+ return future;
+ }
+
+ @Override
+ public RequestStats getStats()
+ {
+ return null;
+ }
+
+ @Override
+ public long getMaxContentLength()
+ {
+ return 0;
+ }
+
+ @Override
+ public void close()
+ {
+ }
+
+ @Override
+ public boolean isClosed()
+ {
+ return false;
+ }
+ }
+
+ /**
+ * A stateful response manager for testing purpose. The lifetime of an instantiation of this class should be equivalent to the lifetime of the http client.
+ */
+ public static class TestingResponseManager
+ {
+ private static final JsonCodec taskInfoCodec = JsonCodec.jsonCodec(TaskInfo.class);
+ private static final JsonCodec serverInfoCodec = JsonCodec.jsonCodec(ServerInfo.class);
+ private final TestingResultResponseManager resultResponseManager;
+ private final TestingServerResponseManager serverResponseManager;
+ private final TestingTaskInfoResponseManager taskInfoResponseManager;
+ private final String taskId;
+
+ // Default: all three sub-managers use their happy-path behavior.
+ public TestingResponseManager(String taskId)
+ {
+ this.taskId = requireNonNull(taskId, "taskId is null");
+ this.resultResponseManager = new TestingResultResponseManager();
+ this.serverResponseManager = new TestingServerResponseManager();
+ this.taskInfoResponseManager = new TestingTaskInfoResponseManager();
+ }
+
+ // TaskInfo responses report the given task state (e.g. FINISHED).
+ public TestingResponseManager(String taskId, TaskState taskState)
+ {
+ this.taskId = requireNonNull(taskId, "taskId is null");
+ this.resultResponseManager = new TestingResultResponseManager();
+ this.serverResponseManager = new TestingServerResponseManager();
+ this.taskInfoResponseManager = new TestingTaskInfoResponseManager(taskState);
+ }
+
+ // Custom result-endpoint behavior; other endpoints use defaults.
+ public TestingResponseManager(String taskId, TestingResultResponseManager resultResponseManager)
+ {
+ this.taskId = requireNonNull(taskId, "taskId is null");
+ this.resultResponseManager = requireNonNull(resultResponseManager, "resultResponseManager is null.");
+ this.serverResponseManager = new TestingServerResponseManager();
+ this.taskInfoResponseManager = new TestingTaskInfoResponseManager();
+ }
+
+ // Custom server-info behavior; other endpoints use defaults.
+ public TestingResponseManager(String taskId, TestingServerResponseManager serverResponseManager)
+ {
+ this.taskId = requireNonNull(taskId, "taskId is null");
+ this.resultResponseManager = new TestingResultResponseManager();
+ this.taskInfoResponseManager = new TestingTaskInfoResponseManager();
+ this.serverResponseManager = requireNonNull(serverResponseManager, "serverResponseManager is null");
+ }
+
+ // Custom taskInfo behavior; other endpoints use defaults.
+ public TestingResponseManager(String taskId, TestingTaskInfoResponseManager taskInfoResponseManager)
+ {
+ this.taskId = requireNonNull(taskId, "taskId is null");
+ this.resultResponseManager = new TestingResultResponseManager();
+ this.serverResponseManager = new TestingServerResponseManager();
+ this.taskInfoResponseManager = requireNonNull(taskInfoResponseManager, "taskInfoResponseManager is null");
+ }
+
+ // Empty 200 response, used for acknowledge/abort endpoints.
+ public Response createDummyResultResponse()
+ {
+ return new TestingResponse();
+ }
+
+ public Response createResultResponse()
+ throws PageTransportErrorException
+ {
+ return resultResponseManager.createResultResponse(taskId);
+ }
+
+ public Response createServerInfoResponse()
+ throws PrestoException
+ {
+ return serverResponseManager.createServerInfoResponse();
+ }
+
+ public Response createTaskInfoResponse(HttpStatus httpStatus)
+ throws PrestoException
+ {
+ return taskInfoResponseManager.createTaskInfoResponse(httpStatus, taskId);
+ }
+
+ /**
+ * Manager for server related endpoints. It maintains any stateful information inside itself. Callers can extend this class to create their own response handling
+ * logic.
+ */
+ public static class TestingServerResponseManager
+ {
+ // Default: a fixed ServerInfo JSON payload with HTTP 200.
+ public Response createServerInfoResponse()
+ throws PrestoException
+ {
+ ServerInfo serverInfo = new ServerInfo(UNKNOWN, "test", true, false, Optional.of(Duration.valueOf("2m")));
+ HttpStatus httpStatus = HttpStatus.OK;
+ ListMultimap headers = ArrayListMultimap.create();
+ headers.put(HeaderName.of(CONTENT_TYPE), String.valueOf(MediaType.create("application", "json")));
+ return new TestingResponse(
+ httpStatus.code(),
+ headers,
+ new ByteArrayInputStream(serverInfoCodec.toBytes(serverInfo)));
+ }
+ }
+
+ /**
+ * Manager for result fetching related endpoints. It maintains any stateful information inside itself. Callers can extend this class to create their own response handling
+ * logic.
+ */
+ public static class TestingResultResponseManager
+ {
+ /**
+ * A dummy implementation of result creation logic. It shall be overriden by users to create customized result returning logic.
+ */
+ public Response createResultResponse(String taskId)
+ throws PageTransportErrorException
+ {
+ // Single empty page, token 0 -> 1, buffer complete.
+ return createResultResponseHelper(HttpStatus.OK,
+ taskId,
+ 0,
+ 1,
+ true,
+ 0);
+ }
+
+ // Serializes one page of the requested size and attaches the Presto
+ // paging headers (token, next token, buffer-complete, instance id).
+ protected Response createResultResponseHelper(
+ HttpStatus httpStatus,
+ String taskId,
+ long token,
+ long nextToken,
+ boolean bufferComplete,
+ int serializedPageSizeBytes)
+ {
+ DynamicSliceOutput slicedOutput = new DynamicSliceOutput(1024);
+ PagesSerdeUtil.writeSerializedPage(slicedOutput, createSerializedPage(serializedPageSizeBytes));
+ ListMultimap headers = ArrayListMultimap.create();
+ headers.put(HeaderName.of(PRESTO_PAGE_TOKEN), String.valueOf(token));
+ headers.put(HeaderName.of(PRESTO_PAGE_NEXT_TOKEN), String.valueOf(nextToken));
+ headers.put(HeaderName.of(PRESTO_BUFFER_COMPLETE), String.valueOf(bufferComplete));
+ headers.put(HeaderName.of(PRESTO_TASK_INSTANCE_ID), taskId);
+ headers.put(HeaderName.of(CONTENT_TYPE), PRESTO_PAGES_TYPE.toString());
+ return new TestingResponse(
+ httpStatus.code(),
+ headers,
+ slicedOutput.slice().getInput());
+ }
+ }
+
+ /**
+ * Manager for taskInfo fetching related endpoints. It maintains any stateful information inside itself. Callers can extend this class to create their own response handling
+ * logic.
+ */
+ public static class TestingTaskInfoResponseManager
+ {
+ // State reported in every TaskInfo response (default PLANNED).
+ private final TaskState taskState;
+
+ public TestingTaskInfoResponseManager()
+ {
+ taskState = TaskState.PLANNED;
+ }
+
+ public TestingTaskInfoResponseManager(TaskState taskState)
+ {
+ this.taskState = taskState;
+ }
+
+ // JSON-encodes an initial TaskInfo whose status carries taskState.
+ public Response createTaskInfoResponse(HttpStatus httpStatus, String taskId)
+ throws PrestoException
+ {
+ URI location = uriBuilderFrom(BASE_URI).appendPath(TASK_ROOT_PATH).build();
+ ListMultimap headers = ArrayListMultimap.create();
+ headers.put(HeaderName.of(CONTENT_TYPE), String.valueOf(MediaType.create("application", "json")));
+ TaskInfo taskInfo = TaskInfo.createInitialTask(
+ TaskId.valueOf(taskId),
+ location,
+ new ArrayList<>(),
+ new TaskStats(System.currentTimeMillis(), 0L),
+ "dummy-node").withTaskStatus(createTaskStatusDone(location));
+ return new TestingResponse(
+ httpStatus.code(),
+ headers,
+ new ByteArrayInputStream(taskInfoCodec.toBytes(taskInfo)));
+ }
+
+ // Zeroed-out TaskStatus whose only meaningful field is taskState.
+ private TaskStatus createTaskStatusDone(URI location)
+ {
+ return new TaskStatus(
+ 0L,
+ 0L,
+ 0,
+ taskState,
+ location,
+ ImmutableSet.of(),
+ ImmutableList.of(),
+ 0,
+ 0,
+ 0.0,
+ false,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0L,
+ 0L);
+ }
+ }
+
+ // TaskInfo manager that always fails, simulating an unreachable server.
+ public static class CrashingTaskInfoResponseManager
+ extends TestingResponseManager.TestingTaskInfoResponseManager
+ {
+ public CrashingTaskInfoResponseManager()
+ {
+ super();
+ }
+
+ @Override
+ public Response createTaskInfoResponse(HttpStatus httpStatus, String taskId)
+ throws PrestoException
+ {
+ throw new RuntimeException("Server refused connection");
+ }
+ }
+ }
+
+ // Simple in-memory Response: fixed status, headers, and body stream.
+ // The no-arg constructor yields an empty 200 (inputStream stays null).
+ public static class TestingResponse
+ implements Response
+ {
+ private final int statusCode;
+ private final ListMultimap headers;
+ private InputStream inputStream;
+
+ private TestingResponse()
+ {
+ this.statusCode = HttpStatus.OK.code();
+ this.headers = ArrayListMultimap.create();
+ }
+
+ private TestingResponse(
+ int statusCode,
+ ListMultimap headers,
+ InputStream inputStream)
+ {
+ this.statusCode = statusCode;
+ this.headers = headers;
+ this.inputStream = inputStream;
+ }
+
+ @Override
+ public int getStatusCode()
+ {
+ return statusCode;
+ }
+
+ @Override
+ public ListMultimap getHeaders()
+ {
+ return headers;
+ }
+
+ @Override
+ public long getBytesRead()
+ {
+ // Byte accounting is not tracked in tests.
+ return 0;
+ }
+
+ @Override
+ public InputStream getInputStream()
+ {
+ return inputStream;
+ }
+ }
+
+ // Builds a SerializedPage of numBytes filler bytes (all 0x08) with no
+ // codec markers, used as canned result payloads.
+ private static SerializedPage createSerializedPage(int numBytes)
+ {
+ byte[] bytes = new byte[numBytes];
+ Arrays.fill(bytes, (byte) 8);
+ Slice slice = Slices.wrappedBuffer(bytes);
+ return new SerializedPage(
+ slice,
+ PageCodecMarker.none(),
+ 0,
+ numBytes,
+ 0);
+ }
+
+ // Server-info manager that fails the first maxRetryCount requests, then
+ // delegates to the default success response — exercises retry paths.
+ public static class FailureRetryResponseManager
+ extends TestingResponseManager.TestingServerResponseManager
+ {
+ private final int maxRetryCount;
+ private int retryCount;
+
+ public FailureRetryResponseManager(int maxRetryCount)
+ {
+ this.maxRetryCount = maxRetryCount;
+ }
+
+ @Override
+ public Response createServerInfoResponse()
+ throws PrestoException
+ {
+ // Fail until maxRetryCount requests have been rejected.
+ if (retryCount++ < maxRetryCount) {
+ throw new RuntimeException("Get ServerInfo request failure.");
+ }
+
+ return super.createServerInfoResponse();
+ }
+ }
+
+ // TaskInfo manager that fails the first maxRetryCount requests with a
+ // retriable error, then succeeds — used by testUpdateTaskWithRetries.
+ public static class FailureRetryTaskInfoResponseManager
+ extends TestingResponseManager.TestingTaskInfoResponseManager
+ {
+ private final int maxRetryCount;
+ private int retryCount;
+
+ public FailureRetryTaskInfoResponseManager(int maxRetryCount)
+ {
+ this.maxRetryCount = maxRetryCount;
+ }
+
+ @Override
+ public Response createTaskInfoResponse(HttpStatus httpStatus, String taskId)
+ throws PrestoException
+ {
+ if (retryCount++ < maxRetryCount) {
+ throw new RuntimeException("retriable failure");
+ }
+
+ return super.createTaskInfoResponse(httpStatus, taskId);
+ }
+ }
+
+ // TaskInfo manager that SUCCEEDS for the first failureCount+1 requests and
+ // then fails every subsequent one (note the `>` comparison — the inverse
+ // of FailureRetryTaskInfoResponseManager). This matches
+ // testInfoFetcherWithRetry, which expects an initial success followed by
+ // persistent failures that exhaust the error budget.
+ private static class FailureTaskInfoRetryResponseManager
+ extends TestingResponseManager.TestingTaskInfoResponseManager
+ {
+ private final int failureCount;
+ private int retryCount;
+
+ public FailureTaskInfoRetryResponseManager(int failureCount)
+ {
+ super();
+ this.failureCount = failureCount;
+ }
+
+ @Override
+ public Response createTaskInfoResponse(HttpStatus httpStatus, String taskId)
+ throws PrestoException
+ {
+ if (retryCount++ > failureCount) {
+ throw new RuntimeException("retriable failure");
+ }
+
+ return super.createTaskInfoResponse(httpStatus, taskId);
+ }
+ }
+
+ // TaskInfo manager that returns a single HTTP 500 response (ignoring the
+ // requested status); a second call indicates the fetcher wrongly retried.
+ private static class UnexpectedResponseTaskInfoRetryResponseManager
+ extends TestingResponseManager.TestingTaskInfoResponseManager
+ {
+ private int requestCount;
+
+ @Override
+ public Response createTaskInfoResponse(HttpStatus httpStatus, String taskId)
+ throws PrestoException
+ {
+ if (requestCount == 0) {
+ requestCount++;
+ return super.createTaskInfoResponse(HttpStatus.INTERNAL_SERVER_ERROR, taskId);
+ }
+ throw new RuntimeException("response handler is not expected to be called more than once");
+ }
+ }
+}
diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/nativeprocess/TestHttpNativeExecutionTaskInfoFetcher.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/nativeprocess/TestHttpNativeExecutionTaskInfoFetcher.java
new file mode 100644
index 0000000000000..68acb4bdb50e5
--- /dev/null
+++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/nativeprocess/TestHttpNativeExecutionTaskInfoFetcher.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.execution.nativeprocess;
+
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.presto.execution.TaskId;
+import com.facebook.presto.execution.TaskInfo;
+import com.facebook.presto.spark.execution.http.BatchTaskUpdateRequest;
+import com.facebook.presto.spark.execution.http.PrestoSparkHttpTaskClient;
+import com.facebook.presto.spark.execution.http.TestPrestoSparkHttpClient;
+import com.facebook.presto.sql.planner.PlanFragment;
+import io.airlift.units.Duration;
+import org.testng.annotations.Test;
+
+import java.net.URI;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static java.util.concurrent.Executors.newScheduledThreadPool;
+import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.fail;
+
+public class TestHttpNativeExecutionTaskInfoFetcher
+{
+ private static final URI BASE_URI = URI.create("http://localhost");
+ private static final TaskId TEST_TASK_ID = TaskId.valueOf("test.0.0.0.0");
+ private static final JsonCodec TASK_INFO_JSON_CODEC = JsonCodec.jsonCodec(TaskInfo.class);
+ private static final JsonCodec PLAN_FRAGMENT_JSON_CODEC = JsonCodec.jsonCodec(PlanFragment.class);
+ private static final JsonCodec TASK_UPDATE_REQUEST_JSON_CODEC = JsonCodec.jsonCodec(BatchTaskUpdateRequest.class);
+ private static final ScheduledExecutorService updateScheduledExecutor = newScheduledThreadPool(4);
+
+ @Test
+ public void testNativeExecutionTaskFailsWhenProcessCrashes()
+ {
+ PrestoSparkHttpTaskClient workerClient = new PrestoSparkHttpTaskClient(
+ new TestPrestoSparkHttpClient.TestingHttpClient(
+ updateScheduledExecutor,
+ new TestPrestoSparkHttpClient.TestingResponseManager(
+ TEST_TASK_ID.toString(),
+ new TestPrestoSparkHttpClient.TestingResponseManager.CrashingTaskInfoResponseManager())),
+ TEST_TASK_ID,
+ BASE_URI,
+ TASK_INFO_JSON_CODEC,
+ PLAN_FRAGMENT_JSON_CODEC,
+ TASK_UPDATE_REQUEST_JSON_CODEC,
+ // deliberately tiny interval so the fetcher's failure handling triggers quickly in this unit test
+ new Duration(1, TimeUnit.MILLISECONDS),
+ updateScheduledExecutor,
+ updateScheduledExecutor,
+ new Duration(1, TimeUnit.SECONDS));
+
+ Object taskFinishedOrLostSignal = new Object();
+
+ HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = new HttpNativeExecutionTaskInfoFetcher(
+ updateScheduledExecutor,
+ workerClient,
+ new Duration(1, TimeUnit.SECONDS),
+ taskFinishedOrLostSignal);
+
+ // set up a listener for the notification
+ AtomicBoolean notifyCalled = new AtomicBoolean(false);
+
+ // There is no direct way to verify that notify() was invoked on
+ // taskFinishedOrLostSignal, so the observer thread records the wake-up
+ // in notifyCalled and relays it to the test thread via testSignallingObject.
+ Object testSignallingObject = new Object();
+
+ Thread observerThread = new Thread(() -> {
+ try {
+ synchronized (taskFinishedOrLostSignal) {
+ while (!Thread.interrupted() && taskInfoFetcher.getLastException().get() != null) {
+ taskFinishedOrLostSignal.wait();
+ }
+ notifyCalled.set(true);
+ synchronized (testSignallingObject) {
+ testSignallingObject.notifyAll();
+ }
+ }
+ }
+ catch (InterruptedException ex) {
+ fail("Error in test observer thread waiting for notification from info fetcher");
+ }
+ });
+ observerThread.start();
+
+ taskInfoFetcher.doGetTaskInfo();
+
+ // test that notify was called
+ try {
+ synchronized (testSignallingObject) {
+ while (!notifyCalled.get()) {
+ testSignallingObject.wait();
+ }
+ }
+ }
+ catch (InterruptedException ex) {
+ fail("Exception while waiting for info fetcher to signal process crash", ex);
+ }
+ observerThread.interrupt();
+
+ assertTrue(notifyCalled.get());
+ }
+}
diff --git a/presto-spark-classloader-interface/pom.xml b/presto-spark-classloader-interface/pom.xml
index eecde0fa81386..92753ae735f28 100644
--- a/presto-spark-classloader-interface/pom.xml
+++ b/presto-spark-classloader-interface/pom.xml
@@ -11,12 +11,25 @@
${project.parent.basedir}
+ true
- com.facebook.presto.spark
- spark-core
+ org.apache.spark
+ spark-core_2.13
+ provided
+
+
+ com.esotericsoftware
+ kryo-shaded
+ 4.0.2
+ provided
+
+
+ org.scala-lang
+ scala-library
+ 2.13.8
provided
diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/IPrestoSparkService.java b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/IPrestoSparkService.java
index 57dc25eb02232..d64e3da7fd759 100644
--- a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/IPrestoSparkService.java
+++ b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/IPrestoSparkService.java
@@ -20,5 +20,7 @@ public interface IPrestoSparkService
IPrestoSparkTaskExecutorFactory getTaskExecutorFactory();
+ IPrestoSparkTaskExecutorFactory getNativeTaskExecutorFactory();
+
void close();
}
diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeExecutionShuffleManager.java b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeExecutionShuffleManager.java
index 156292e96e9a2..8678c06b46c63 100644
--- a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeExecutionShuffleManager.java
+++ b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeExecutionShuffleManager.java
@@ -23,7 +23,9 @@
import org.apache.spark.shuffle.ShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleHandle;
import org.apache.spark.shuffle.ShuffleManager;
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
import org.apache.spark.shuffle.ShuffleReader;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.sort.BypassMergeSortShuffleHandle;
import org.apache.spark.storage.BlockManager;
@@ -77,26 +79,26 @@ private static T instantiateClass(String className, SparkConf conf)
}
}
- protected void registerShuffleHandle(BaseShuffleHandle handle, int stageId, int mapId)
+ protected void registerShuffleHandle(BaseShuffleHandle handle, int stageId, long mapId)
{
partitionIdToShuffleHandle.put(new StageAndMapId(stageId, mapId), handle);
shuffleIdToBaseShuffleHandle.put(handle.shuffleId(), handle);
}
- protected void unregisterShuffleHandle(int shuffleId, int stageId, int mapId)
+ protected void unregisterShuffleHandle(int shuffleId, int stageId, long mapId)
{
partitionIdToShuffleHandle.remove(new StageAndMapId(stageId, mapId));
shuffleIdToBaseShuffleHandle.remove(shuffleId);
}
@Override
- public ShuffleHandle registerShuffle(int shuffleId, int numMaps, ShuffleDependency dependency)
+ public ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency dependency)
{
- return fallbackShuffleManager.registerShuffle(shuffleId, numMaps, dependency);
+ return fallbackShuffleManager.registerShuffle(shuffleId, dependency);
}
@Override
- public ShuffleWriter getWriter(ShuffleHandle handle, int mapId, TaskContext context)
+ public ShuffleWriter getWriter(ShuffleHandle handle, long mapId, TaskContext context, ShuffleWriteMetricsReporter metrics)
{
checkState(
requireNonNull(handle, "handle is null") instanceof BypassMergeSortShuffleHandle,
@@ -111,7 +113,13 @@ public ShuffleWriter getWriter(ShuffleHandle handle, int mapId, Tas
}
@Override
- public ShuffleReader getReader(ShuffleHandle handle, int startPartition, int endPartition, TaskContext context)
+ public ShuffleReader getReader(ShuffleHandle handle, int startPartition, int endPartition, TaskContext context, ShuffleReadMetricsReporter metrics)
+ {
+ return new EmptyShuffleReader<>();
+ }
+
+ @Override
+ public ShuffleReader getReader(ShuffleHandle handle, int startMapIndex, int endMapIndex, int startPartition, int endPartition, TaskContext context, ShuffleReadMetricsReporter metrics)
{
return new EmptyShuffleReader<>();
}
@@ -196,16 +204,23 @@ public Option stop(boolean success)
{
onStop.run();
BlockManager blockManager = SparkEnv.get().blockManager();
- return Option.apply(MapStatus$.MODULE$.apply(blockManager.blockManagerId(), mapStatus));
+ return Option.apply(
+ MapStatus$.MODULE$.apply(blockManager.blockManagerId(), mapStatus, 0L));
+ }
+
+ @Override
+ public long[] getPartitionLengths()
+ {
+ return mapStatus;
}
}
public static class StageAndMapId
{
private final int stageId;
- private final int mapId;
+ private final long mapId;
- public StageAndMapId(int stageId, int mapId)
+ public StageAndMapId(int stageId, long mapId)
{
this.stageId = stageId;
this.mapId = mapId;
@@ -216,7 +231,7 @@ public int getStageId()
return stageId;
}
- public int getMapId()
+ public long getMapId()
{
return mapId;
}
diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeTaskRdd.java b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeTaskRdd.java
new file mode 100644
index 0000000000000..a524767a88ead
--- /dev/null
+++ b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/PrestoSparkNativeTaskRdd.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spark.classloader_interface;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.spark.MapOutputTracker;
+import org.apache.spark.Partition;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.SparkContext;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.rdd.ShuffledRDD;
+import org.apache.spark.rdd.ShuffledRDDPartition;
+import org.apache.spark.rdd.ZippedPartitionsPartition;
+import org.apache.spark.shuffle.ShuffleHandle;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManagerId;
+import scala.Tuple2;
+import scala.collection.Iterator;
+import scala.collection.Seq;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+import static com.facebook.presto.spark.classloader_interface.ScalaUtils.emptyScalaIterator;
+import static com.google.common.base.Preconditions.checkState;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+import static scala.collection.JavaConversions.asJavaCollection;
+import static scala.collection.JavaConversions.seqAsJavaList;
+
+/**
+ * PrestoSparkNativeTaskRdd represents execution of Presto stage, it contains:
+ * - A list of shuffleInputRdds, each of the corresponding to a child stage.
+ * - An optional taskSourceRdd, which represents ALL table scan inputs in this stage.
+ *
+ * Table scan is present when joining a bucketed table with an unbucketed table, for example:
+ * Join
+ * / \
+ * Scan Remote Source
+ *
+ * In this case, bucket to Spark partition mapping has to be consistent with the Spark shuffle partition.
+ *
+ * When the stage partitioning is SINGLE_DISTRIBUTION and the shuffleInputRdds is empty,
+ * the taskSourceRdd is expected to be present and contain exactly one empty partition.
+ *
+ * The broadcast inputs are encapsulated in taskProcessor.
+ */
+public class PrestoSparkNativeTaskRdd
+ extends PrestoSparkTaskRdd
+{
+ public static PrestoSparkNativeTaskRdd create(
+ SparkContext context,
+ Optional taskSourceRdd,
+ // fragmentId -> RDD
+ Map>> shuffleInputRddMap,
+ PrestoSparkTaskProcessor taskProcessor)
+ {
+ requireNonNull(context, "context is null");
+ requireNonNull(taskSourceRdd, "taskSourceRdd is null");
+ requireNonNull(shuffleInputRddMap, "shuffleInputRddMap is null");
+ requireNonNull(taskProcessor, "taskProcessor is null");
+ ImmutableList.Builder shuffleInputFragmentIds = ImmutableList.builder();
+ ImmutableList.Builder>> shuffleInputRdds = ImmutableList.builder();
+ for (Map.Entry>> entry : shuffleInputRddMap.entrySet()) {
+ shuffleInputFragmentIds.add(entry.getKey());
+ shuffleInputRdds.add(entry.getValue());
+ }
+ return new PrestoSparkNativeTaskRdd<>(context, taskSourceRdd, shuffleInputFragmentIds.build(), shuffleInputRdds.build(), taskProcessor);
+ }
+
+ @Override
+ public Iterator> compute(Partition split, TaskContext context)
+ {
+ PrestoSparkTaskSourceRdd taskSourceRdd = getTaskSourceRdd();
+ List partitions = seqAsJavaList(((ZippedPartitionsPartition) split).partitions());
+ int expectedPartitionsSize = (taskSourceRdd != null ? 1 : 0) + getShuffleInputRdds().size();
+ checkState(partitions.size() == expectedPartitionsSize,
+ format("Unexpected partitions size. Expected: %s. Actual: %s.", expectedPartitionsSize, partitions.size()));
+
+ Iterator taskSourceIterator;
+ if (taskSourceRdd != null) {
+ taskSourceIterator = taskSourceRdd.iterator(partitions.get(partitions.size() - 1), context);
+ }
+ else {
+ taskSourceIterator = emptyScalaIterator();
+ }
+
+ return getTaskProcessor().process(
+ taskSourceIterator,
+ getShuffleReadDescriptors(partitions),
+ getShuffleWriteDescriptor(context.stageId(), split));
+ }
+
+ private PrestoSparkNativeTaskRdd(
+ SparkContext context,
+ Optional taskSourceRdd,
+ List