diff --git a/.github/workflows/arrow-flight-tests.yml b/.github/workflows/arrow-flight-tests.yml index 3c9582c790bbe..e645ef03b17cf 100644 --- a/.github/workflows/arrow-flight-tests.yml +++ b/.github/workflows/arrow-flight-tests.yml @@ -14,7 +14,7 @@ env: RETRY: .github/bin/retry jobs: - test: + arrowflight-java-tests: runs-on: ubuntu-latest strategy: fail-fast: false @@ -48,6 +48,161 @@ jobs: export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl ${{ matrix.modules }} - # Run Maven tests for the target module + # Run Maven tests for the target module, excluding native tests - name: Maven Tests - run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }} + run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }} -Dtest="*,!TestArrowFlightNativeQueries" + + prestocpp-linux-build-for-test: + runs-on: ubuntu-22.04 + container: + image: prestodb/presto-native-dependency:0.292-20250204112033-cf8ba84 + env: + CCACHE_DIR: "${{ github.workspace }}/ccache" + DEPENDENCY_DIR: "${{ github.workspace }}/adapter-deps/download" + INSTALL_PREFIX: "${{ github.workspace }}/adapter-deps/install" + steps: + - uses: actions/checkout@v4 + + - name: Fix git permissions + # Usually actions/checkout does this but as we run in a container + # it doesn't work + run: git config --global --add safe.directory ${GITHUB_WORKSPACE} + + - name: Update velox + run: | + cd presto-native-execution + make velox-submodule + + - name: Install Arrow Flight + run: | + mkdir -p ${DEPENDENCY_DIR}/adapter-deps/download + mkdir -p ${INSTALL_PREFIX}/adapter-deps/install + source /opt/rh/gcc-toolset-12/enable + set -xu + cd presto-native-execution + PROMPT_ALWAYS_RESPOND=n ./scripts/setup-adapters.sh arrow_flight + + - name: Install Github CLI for using apache/infrastructure-actions/stash + run: | + curl -L https://github.com/cli/cli/releases/download/v2.63.2/gh_2.63.2_linux_amd64.rpm > gh_2.63.2_linux_amd64.rpm + rpm -iv gh_2.63.2_linux_amd64.rpm + + - uses: 
apache/infrastructure-actions/stash/restore@4ab8682fbd4623d2b4fc1c98db38aba5091924c3 + with: + path: '${{ env.CCACHE_DIR }}' + key: ccache-prestocpp-linux-build-for-test + + - name: Zero ccache statistics + run: ccache -sz + + - name: Build engine + run: | + source /opt/rh/gcc-toolset-12/enable + cd presto-native-execution + cmake \ + -B _build/release \ + -GNinja \ + -DTREAT_WARNINGS_AS_ERRORS=1 \ + -DENABLE_ALL_WARNINGS=1 \ + -DVELOX_ENABLE_ARROW=OFF \ + -DVELOX_ENABLE_PARQUET=OFF \ + -DPRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR=ON \ + -DCMAKE_PREFIX_PATH=/usr/local \ + -DThrift_ROOT=/usr/local \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DMAX_LINK_JOBS=4 + ninja -C _build/release -j 4 + + - name: Ccache after + run: ccache -s + + - uses: apache/infrastructure-actions/stash/save@4ab8682fbd4623d2b4fc1c98db38aba5091924c3 + with: + path: '${{ env.CCACHE_DIR }}' + key: ccache-prestocpp-linux-build-for-test + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: arrow-flight-presto-native-build + path: presto-native-execution/_build/release/presto_cpp/main/presto_server + + - name: Upload Arrow Flight install artifacts + uses: actions/upload-artifact@v4 + with: + name: arrow-flight-install + path: ${{ env.INSTALL_PREFIX }}/lib64/libarrow_flight* + + arrowflight-native-e2e-tests: + needs: prestocpp-linux-build-for-test + runs-on: ubuntu-22.04 + container: + image: prestodb/presto-native-dependency:0.292-20250204112033-cf8ba84 + env: + INSTALL_PREFIX: "${{ github.workspace }}/adapter-deps/install" + strategy: + fail-fast: false + matrix: + modules: + - ":presto-base-arrow-flight" # Only run tests for the `presto-base-arrow-flight` module + + timeout-minutes: 80 + concurrency: + group: ${{ github.workflow }}-test-${{ matrix.modules }}-${{ github.event.pull_request.number }} + cancel-in-progress: true + + steps: + - uses: actions/checkout@v4 + + - name: Fix git permissions + # Usually actions/checkout does this but as we run in a container + # it 
doesn't work + run: git config --global --add safe.directory ${GITHUB_WORKSPACE} + + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + name: arrow-flight-presto-native-build + path: presto-native-execution/_build/release/presto_cpp/main + + - name: Download Arrow Flight install artifacts + uses: actions/download-artifact@v4 + with: + name: arrow-flight-install + path: ${{ env.INSTALL_PREFIX }}/lib64 + + # Permissions are lost when uploading. Details here: https://github.com/actions/upload-artifact/issues/38 + - name: Restore execute permissions and library path + run: | + chmod +x ${GITHUB_WORKSPACE}/presto-native-execution/_build/release/presto_cpp/main/presto_server + # Ensure transitive dependency libboost-iostreams is found. + ldconfig /usr/local/lib + + - name: Install OpenJDK8 + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: 8 + cache: 'maven' + - name: Download nodejs to maven cache + run: .github/bin/download_nodejs + + - name: Maven install + env: + # Use different Maven options to install. 
+ MAVEN_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError" + run: | + export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" + ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl ${{ matrix.modules }} + + - name: Run arrowflight native e2e tests + run: | + export PRESTO_SERVER_PATH="${GITHUB_WORKSPACE}/presto-native-execution/_build/release/presto_cpp/main/presto_server" + mvn test \ + ${MAVEN_TEST} \ + -pl ${{ matrix.modules }} \ + -Dtest="TestArrowFlightNativeQueries" \ + -DPRESTO_SERVER=${PRESTO_SERVER_PATH} \ + -DDATA_DIR=${RUNNER_TEMP} \ + -Duser.timezone=America/Bahia_Banderas \ + -T1C diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java index 384b95b06ad56..50aa4475e61b3 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java @@ -26,9 +26,13 @@ import org.apache.arrow.memory.RootAllocator; import java.io.File; +import java.net.URI; import java.util.Map; import java.util.Optional; +import java.util.function.BiFunction; +import static com.facebook.plugin.arrow.testingConnector.TestingArrowFlightPlugin.ARROW_FLIGHT_CATALOG; +import static com.facebook.plugin.arrow.testingConnector.TestingArrowFlightPlugin.ARROW_FLIGHT_CONNECTOR; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; public class ArrowFlightQueryRunner @@ -40,21 +44,28 @@ private ArrowFlightQueryRunner() public static DistributedQueryRunner createQueryRunner(int flightServerPort) throws Exception { - return createQueryRunner(ImmutableMap.of("arrow-flight.server.port", String.valueOf(flightServerPort))); + return createQueryRunner(flightServerPort, ImmutableMap.of(), ImmutableMap.of(), Optional.empty()); } - private static DistributedQueryRunner createQueryRunner(Map catalogProperties) throws Exception + public 
static DistributedQueryRunner createQueryRunner( + int flightServerPort, + Map extraProperties, + Map coordinatorProperties, + Optional> externalWorkerLauncher) + throws Exception { - return createQueryRunner(ImmutableMap.of(), catalogProperties); + return createQueryRunner(extraProperties, ImmutableMap.of("arrow-flight.server.port", String.valueOf(flightServerPort)), coordinatorProperties, externalWorkerLauncher); } private static DistributedQueryRunner createQueryRunner( Map extraProperties, - Map catalogProperties) + Map catalogProperties, + Map coordinatorProperties, + Optional> externalWorkerLauncher) throws Exception { Session session = testSessionBuilder() - .setCatalog("arrowflight") + .setCatalog(ARROW_FLIGHT_CATALOG) .setSchema("tpch") .build(); @@ -62,10 +73,14 @@ private static DistributedQueryRunner createQueryRunner( Optional workerCount = getProperty("WORKER_COUNT").map(Integer::parseInt); workerCount.ifPresent(queryRunnerBuilder::setNodeCount); - DistributedQueryRunner queryRunner = queryRunnerBuilder.setExtraProperties(extraProperties).build(); + DistributedQueryRunner queryRunner = queryRunnerBuilder + .setExtraProperties(extraProperties) + .setCoordinatorProperties(coordinatorProperties) + .setExternalWorkerLauncher(externalWorkerLauncher).build(); try { - queryRunner.installPlugin(new TestingArrowFlightPlugin()); + boolean nativeExecution = externalWorkerLauncher.isPresent(); + queryRunner.installPlugin(new TestingArrowFlightPlugin(nativeExecution)); ImmutableMap.Builder properties = ImmutableMap.builder() .putAll(catalogProperties) @@ -73,16 +88,16 @@ private static DistributedQueryRunner createQueryRunner( .put("arrow-flight.server-ssl-enabled", "true") .put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt"); - queryRunner.createCatalog("arrowflight", "arrow-flight", properties.build()); + queryRunner.createCatalog(ARROW_FLIGHT_CATALOG, ARROW_FLIGHT_CONNECTOR, properties.build()); return queryRunner; } catch (Exception e) 
{ - throw new RuntimeException("Failed to create ArrowQueryRunner", e); + throw new RuntimeException("Failed to create ArrowFlightQueryRunner", e); } } - private static Optional getProperty(String name) + public static Optional getProperty(String name) { String systemPropertyValue = System.getProperty(name); if (systemPropertyValue != null) { @@ -116,7 +131,9 @@ public static void main(String[] args) DistributedQueryRunner queryRunner = createQueryRunner( ImmutableMap.of("http-server.http.port", "8080"), - ImmutableMap.of("arrow-flight.server.port", String.valueOf(9443))); + ImmutableMap.of("arrow-flight.server.port", String.valueOf(9443)), + ImmutableMap.of(), + Optional.empty()); Thread.sleep(10); log.info("======== SERVER STARTED ========"); log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl()); diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightNativeQueries.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightNativeQueries.java new file mode 100644 index 0000000000000..22827c7efaaef --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightNativeQueries.java @@ -0,0 +1,387 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.plugin.arrow.testingServer.TestingArrowProducer; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.RootAllocator; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.ServerSocket; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.function.BiFunction; + +import static com.facebook.plugin.arrow.ArrowFlightQueryRunner.getProperty; +import static com.facebook.plugin.arrow.testingConnector.TestingArrowFlightPlugin.ARROW_FLIGHT_CATALOG; +import static com.facebook.plugin.arrow.testingConnector.TestingArrowFlightPlugin.ARROW_FLIGHT_CONNECTOR; +import static java.lang.String.format; +import static org.testng.Assert.assertTrue; + +public class TestArrowFlightNativeQueries + extends AbstractTestQueryFramework +{ + private static final Logger log = Logger.get(TestArrowFlightNativeQueries.class); + private RootAllocator allocator; + private int serverPort; + private FlightServer server; + private DistributedQueryRunner arrowFlightQueryRunner; + + @BeforeClass + public void setup() + throws Exception + { + arrowFlightQueryRunner = getDistributedQueryRunner(); + + allocator = new RootAllocator(Long.MAX_VALUE); + Location location = Location.forGrpcTls("localhost", serverPort); + File certChainFile = new 
File("src/test/resources/server.crt"); + File privateKeyFile = new File("src/test/resources/server.key"); + + server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator)) + .useTls(certChainFile, privateKeyFile) + .build(); + + server.start(); + log.info("Server listening on port %s", server.getPort()); + } + + @AfterClass(alwaysRun = true) + public void close() + throws InterruptedException + { + arrowFlightQueryRunner.close(); + server.close(); + allocator.close(); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + // Base class initializes query runner first, need to assign server port here + serverPort = findUnusedPort(); + + Path prestoServerPath = Paths.get(getProperty("PRESTO_SERVER") + .orElse("_build/debug/presto_cpp/main/presto_server")) + .toAbsolutePath(); + assertTrue(Files.exists(prestoServerPath), format("Native worker binary at %s not found. Add -DPRESTO_SERVER= to your JVM arguments.", prestoServerPath)); + log.info("Using PRESTO_SERVER binary at %s", prestoServerPath); + + ImmutableMap coordinatorProperties = ImmutableMap.of("native-execution-enabled", "true"); + String flightCertPath = Paths.get("src/test/resources/server.crt").toAbsolutePath().toString(); + + return ArrowFlightQueryRunner.createQueryRunner(serverPort, getNativeWorkerSystemProperties(), coordinatorProperties, getExternalWorkerLauncher(prestoServerPath.toString(), serverPort, flightCertPath)); + } + + @Override + protected FeaturesConfig createFeaturesConfig() + { + return new FeaturesConfig().setNativeExecutionEnabled(true); + } + + @Test + public void testFiltersAndProjections1() + { + assertQuery("SELECT * FROM nation"); + assertQuery("SELECT * FROM nation WHERE nationkey = 4"); + assertQuery("SELECT * FROM nation WHERE nationkey <> 4"); + assertQuery("SELECT * FROM nation WHERE nationkey < 4"); + assertQuery("SELECT * FROM nation WHERE nationkey <= 4"); + assertQuery("SELECT * FROM nation WHERE nationkey > 4"); + 
assertQuery("SELECT * FROM nation WHERE nationkey >= 4"); + assertQuery("SELECT * FROM nation WHERE nationkey BETWEEN 3 AND 7"); + assertQuery("SELECT * FROM nation WHERE nationkey IN (1, 3, 5)"); + assertQuery("SELECT * FROM nation WHERE nationkey NOT IN (1, 3, 5)"); + assertQuery("SELECT * FROM nation WHERE nationkey NOT IN (1, 8, 11)"); + assertQuery("SELECT * FROM nation WHERE nationkey NOT IN (1, 2, 3)"); + assertQuery("SELECT * FROM nation WHERE nationkey NOT IN (-14, 2)"); + assertQuery("SELECT * FROM nation WHERE nationkey NOT IN (1, 2, 3, 4, 5, 10, 11, 12, 13)"); + } + + @Test + public void testFiltersAndProjections2() + { + assertQuery("SELECT * FROM nation WHERE nationkey NOT BETWEEN 3 AND 7"); + assertQuery("SELECT * FROM nation WHERE nationkey NOT BETWEEN -10 AND 5"); + assertQuery("SELECT * FROM nation WHERE nationkey < 5 OR nationkey > 10"); + assertQuery("SELECT nationkey * 10, nationkey % 5, -nationkey, nationkey / 3 FROM nation"); + assertQuery("SELECT *, nationkey / 3 FROM nation"); + assertQuery("SELECT nationkey IS NULL FROM nation"); + assertQuery("SELECT * FROM nation WHERE name <> 'SAUDI ARABIA'"); + assertQuery("SELECT * FROM nation WHERE name NOT IN ('RUSSIA', 'UNITED STATES', 'CHINA')"); + assertQuery("SELECT * FROM nation WHERE name NOT IN ('aaa', 'bbb', 'ccc', 'ddd')"); + assertQuery("SELECT * FROM nation WHERE name NOT IN ('', ';', 'new country w1th $p3c1@l ch@r@c73r5')"); + assertQuery("SELECT * FROM nation WHERE name NOT BETWEEN 'A' AND 'K'"); // should produce NegatedBytesRange + assertQuery("SELECT * FROM nation WHERE name <= 'B' OR 'G' <= name"); + } + + @Test + public void testFiltersAndProjections3() + { + assertQuery("SELECT * FROM lineitem WHERE shipmode <> 'FOB'"); + assertQuery("SELECT * FROM lineitem WHERE shipmode NOT IN ('RAIL', 'AIR')"); + assertQuery("SELECT * FROM lineitem WHERE shipmode NOT IN ('', 'TRUCK', 'FOB', 'RAIL')"); + + assertQuery("SELECT rand() < 1, random() < 1 FROM nation", "SELECT true, true FROM 
nation"); + + assertQuery("SELECT * FROM lineitem"); + assertQuery("SELECT ceil(discount), ceiling(discount), floor(discount), abs(discount) FROM lineitem"); + assertQuery("SELECT linenumber IN (2, 4, 6) FROM lineitem"); + assertQuery("SELECT orderdate FROM orders WHERE cast(orderdate as DATE) IN (cast('1997-07-29' as DATE), cast('1993-03-13' as DATE)) ORDER BY orderdate LIMIT 10"); + + assertQuery("SELECT * FROM orders"); + + assertQuery("SELECT coalesce(linenumber, -1) FROM lineitem"); + + assertQuery("SELECT * FROM lineitem WHERE linenumber = 1"); + assertQuery("SELECT * FROM lineitem WHERE linenumber > 3"); + } + + @Test + public void testFiltersAndProjections4() + { + assertQuery("SELECT * FROM lineitem WHERE linenumber = 3"); + assertQuery("SELECT * FROM lineitem WHERE linenumber > 5 AND linenumber < 2"); + + assertQuery("SELECT * FROM lineitem WHERE linenumber > 5"); + assertQuery("SELECT * FROM lineitem WHERE linenumber IN (1, 2)"); + + assertQuery("SELECT linenumber, orderkey, discount FROM lineitem WHERE discount > 0.02"); + assertQuery("SELECT linenumber, orderkey, discount FROM lineitem WHERE discount BETWEEN 0.01 AND 0.02"); + + assertQuery("SELECT linenumber, orderkey, discount FROM lineitem WHERE discount > 0.02"); + assertQuery("SELECT linenumber, orderkey, discount FROM lineitem WHERE discount BETWEEN 0.01 AND 0.02"); + assertQuery("SELECT linenumber, orderkey, discount FROM lineitem WHERE tax < 0.02"); + assertQuery("SELECT linenumber, orderkey, discount FROM lineitem WHERE tax BETWEEN 0.02 AND 0.06"); + } + + @Test + public void testFiltersAndProjections6() + { + // query with filter using like + assertQuery("SELECT * FROM lineitem WHERE shipinstruct like 'TAKE BACK%'"); + assertQuery("SELECT * FROM lineitem WHERE shipinstruct like 'TAKE BACK#%' escape '#'"); + + // no row passes the filter + assertQuery( + "SELECT linenumber, orderkey, discount FROM lineitem WHERE discount > 0.2"); + + // Double and float inequality filter + assertQuery("SELECT 
SUM(discount) FROM lineitem WHERE discount != 0.04"); + } + + @Test + public void testTopN() + { + assertQueryOrdered("SELECT nationkey, regionkey FROM nation ORDER BY nationkey LIMIT 5"); + + assertQueryOrdered("SELECT nationkey, regionkey FROM nation ORDER BY nationkey LIMIT 50"); + + assertQueryOrdered( + "SELECT orderkey, partkey, suppkey, linenumber, quantity, extendedprice, discount, tax " + + "FROM lineitem ORDER BY orderkey, linenumber DESC LIMIT 10"); + + assertQueryOrdered( + "SELECT orderkey, partkey, suppkey, linenumber, quantity, extendedprice, discount, tax " + + "FROM lineitem ORDER BY orderkey, linenumber DESC LIMIT 2000"); + + assertQueryOrdered("SELECT nationkey, regionkey FROM nation ORDER BY name LIMIT 15"); + assertQueryOrdered("SELECT nationkey, regionkey FROM nation ORDER BY name DESC LIMIT 15"); + + assertQuery("SELECT linenumber, NULL FROM lineitem ORDER BY 1 LIMIT 23"); + } + + @Test + public void testCast() + { + assertQuery("SELECT CAST(linenumber as TINYINT), CAST(linenumber AS SMALLINT), " + + "CAST(linenumber AS INTEGER), CAST(linenumber AS BIGINT), CAST(quantity AS REAL), " + + "CAST(orderkey AS DOUBLE), CAST(orderkey AS VARCHAR) FROM lineitem"); + + assertQuery("SELECT CAST(0.0 as VARCHAR)"); + + // Cast to varchar(n). + assertQuery("SELECT CAST(comment as VARCHAR(1)) FROM orders"); + assertQuery("SELECT CAST(comment as VARCHAR(1000)) FROM orders WHERE LENGTH(comment) < 1000"); + assertQuery("SELECT CAST(c0 AS VARCHAR(1)) FROM ( VALUES (NULL) ) t(c0)"); + assertQuery("SELECT CAST(c0 AS VARCHAR(1)) FROM ( VALUES ('') ) t(c0)"); + + assertQuery("SELECT CAST(linenumber as TINYINT), CAST(linenumber AS SMALLINT), " + + "CAST(linenumber AS INTEGER), CAST(linenumber AS BIGINT), CAST(quantity AS REAL), " + + "CAST(orderkey AS DOUBLE), CAST(orderkey AS VARCHAR) FROM lineitem"); + + // Casts to varbinary. + assertQuery("SELECT cast(null as varbinary)"); + assertQuery("SELECT cast('' as varbinary)"); + + // Ensure timestamp casts are correct. 
+ assertQuery("SELECT cast(cast(shipdate as varchar) as timestamp) FROM lineitem ORDER BY 1"); + + // Ensure date casts are correct. + assertQuery("SELECT cast(cast(orderdate as varchar) as date) FROM orders ORDER BY 1"); + + // Cast all integer types to short decimal + assertQuery("SELECT CAST(linenumber as DECIMAL(2, 0)) FROM lineitem"); + assertQuery("SELECT CAST(linenumber as DECIMAL(8, 4)) FROM lineitem"); + assertQuery("SELECT CAST(CAST(linenumber as INTEGER) as DECIMAL(15, 6)) FROM lineitem"); + assertQuery("SELECT CAST(nationkey as DECIMAL(18, 6)) FROM nation"); + + // Cast all integer types to long decimal + assertQuery("SELECT CAST(linenumber as DECIMAL(25, 0)) FROM lineitem"); + assertQuery("SELECT CAST(linenumber as DECIMAL(19, 4)) FROM lineitem"); + assertQuery("SELECT CAST(CAST(linenumber as INTEGER) as DECIMAL(20, 6)) FROM lineitem"); + assertQuery("SELECT CAST(nationkey as DECIMAL(22, 6)) FROM nation"); + } + + @Test + public void testSwitch() + { + assertQuery("SELECT case linenumber % 10 when orderkey % 3 then orderkey + 1 when 2 then orderkey + 2 else 0 end FROM lineitem"); + assertQuery("SELECT case linenumber when 1 then 'one' when 2 then 'two' else '...' end FROM lineitem"); + assertQuery("SELECT case when linenumber = 1 then 'one' when linenumber = 2 then 'two' else '...' end FROM lineitem"); + } + + @Test + public void testIn() + { + assertQuery("SELECT linenumber IN (orderkey % 7, partkey % 5, suppkey % 3) FROM lineitem"); + } + + @Test + public void testSubqueries() + { + assertQuery("SELECT name FROM nation WHERE regionkey = (SELECT max(regionkey) FROM region)"); + + // Subquery returns zero rows. + assertQuery("SELECT name FROM nation WHERE regionkey = (SELECT regionkey FROM region WHERE regionkey < 0)"); + + // Subquery returns more than one row. + assertQueryFails("SELECT name FROM nation WHERE regionkey = (SELECT regionkey FROM region)", ".*Expected single row of input. 
Received 5 rows.*"); + } + + @Test + public void testArithmetic() + { + assertQuery("SELECT mod(orderkey, linenumber) FROM lineitem"); + assertQuery("SELECT discount * 0.123 FROM lineitem"); + assertQuery("SELECT ln(totalprice) FROM orders"); + assertQuery("SELECT sqrt(totalprice) FROM orders"); + assertQuery("SELECT radians(totalprice) FROM orders"); + } + + @Test + public void testGreatestLeast() + { + assertQuery("SELECT greatest(linenumber, suppkey, partkey) from lineitem"); + assertQuery("SELECT least(shipdate, commitdate) from lineitem"); + } + + @Test + public void testSign() + { + assertQuery("SELECT sign(totalprice) from orders"); + assertQuery("SELECT sign(-totalprice) from orders"); + assertQuery("SELECT sign(custkey) from orders"); + assertQuery("SELECT sign(-custkey) from orders"); + assertQuery("SELECT sign(shippriority) from orders"); + } + + private static int findUnusedPort() + throws IOException + { + try (ServerSocket socket = new ServerSocket(0)) { + return socket.getLocalPort(); + } + } + + public static Map getNativeWorkerSystemProperties() + { + return ImmutableMap.builder() + .put("native-execution-enabled", "true") + .put("optimizer.optimize-hash-generation", "false") + .put("regex-library", "RE2J") + .put("offset-clause-enabled", "true") + // By default, Presto will expand some functions into its SQL equivalent (e.g. array_duplicates()). + // With Velox, we do not want Presto to replace the function with its SQL equivalent. + // To achieve that, we set inline-sql-functions to false. 
+ .put("inline-sql-functions", "false") + .put("use-alternative-function-signatures", "true") + .build(); + } + + public static Optional> getExternalWorkerLauncher(String prestoServerPath, int flightServerPort, String flightCertPath) + { + return + Optional.of((workerIndex, discoveryUri) -> { + try { + Path dir = Paths.get("/tmp", TestArrowFlightNativeQueries.class.getSimpleName()); + Files.createDirectories(dir); + Path tempDirectoryPath = Files.createTempDirectory(dir, "worker"); + log.info("Temp directory for Worker #%d: %s", workerIndex, tempDirectoryPath.toString()); + + // Write config file - use an ephemeral port for the worker. + String configProperties = format("discovery.uri=%s%n" + + "presto.version=testversion%n" + + "system-memory-gb=4%n" + + "http-server.http.port=0%n", discoveryUri); + + Files.write(tempDirectoryPath.resolve("config.properties"), configProperties.getBytes()); + Files.write(tempDirectoryPath.resolve("node.properties"), + format("node.id=%s%n" + + "node.internal-address=127.0.0.1%n" + + "node.environment=testing%n" + + "node.location=test-location", UUID.randomUUID()).getBytes()); + + Path catalogDirectoryPath = tempDirectoryPath.resolve("catalog"); + Files.createDirectory(catalogDirectoryPath); + + Files.write(catalogDirectoryPath.resolve(format("%s.properties", ARROW_FLIGHT_CATALOG)), + format("connector.name=%s\n" + + "arrow-flight.server=localhost\n" + + "arrow-flight.server.port=%d\n" + + "arrow-flight.server-ssl-enabled=true\n" + + "arrow-flight.server-ssl-certificate=%s", ARROW_FLIGHT_CONNECTOR, flightServerPort, flightCertPath).getBytes()); + + // Disable stack trace capturing as some queries (using TRY) generate a lot of exceptions. + return new ProcessBuilder(prestoServerPath, "--logtostderr=1", "--v=1") + .directory(tempDirectoryPath.toFile()) + .redirectErrorStream(true) + .redirectOutput(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("worker." 
+ workerIndex + ".out").toFile())) + .redirectError(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("worker." + workerIndex + ".out").toFile())) + .start(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightPlugin.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightPlugin.java index 830b5b04b3b5e..edf0e83bfa9a3 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightPlugin.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightPlugin.java @@ -19,8 +19,16 @@ public class TestingArrowFlightPlugin extends ArrowPlugin { + public static final String ARROW_FLIGHT_CATALOG = "arrowflight"; + public static final String ARROW_FLIGHT_CONNECTOR = "arrow-flight"; + + public TestingArrowFlightPlugin(boolean nativeExecution) + { + super(ARROW_FLIGHT_CONNECTOR, new TestingArrowModule(nativeExecution), new JsonModule()); + } + public TestingArrowFlightPlugin() { - super("arrow-flight", new TestingArrowModule(), new JsonModule()); + this(false); } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowModule.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowModule.java index de9bff4cb83a9..04e1ec34f4d2a 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowModule.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowModule.java @@ -26,13 +26,22 @@ public class TestingArrowModule implements Module { + private final boolean nativeExecution; + + public TestingArrowModule(boolean nativeExecution) + { + this.nativeExecution = nativeExecution; + } + @Override public 
void configure(Binder binder) { // Concrete implementation of the BaseFlightClientHandler binder.bind(BaseArrowFlightClientHandler.class).to(TestingArrowFlightClientHandler.class).in(Scopes.SINGLETON); - // Override the ArrowBlockBuilder with an implementation that handles h2 types - binder.bind(ArrowBlockBuilder.class).to(TestingArrowBlockBuilder.class).in(Scopes.SINGLETON); + // Override the ArrowBlockBuilder with an implementation that handles h2 types, skip for native + if (!nativeExecution) { + binder.bind(ArrowBlockBuilder.class).to(TestingArrowBlockBuilder.class).in(Scopes.SINGLETON); + } // Request/response objects jsonCodecBinder(binder).bindJsonCodec(TestingArrowFlightRequest.class); jsonCodecBinder(binder).bindJsonCodec(TestingArrowFlightResponse.class); diff --git a/presto-native-execution/CMakeLists.txt b/presto-native-execution/CMakeLists.txt index d5001dde70a74..6ac789b73e3a1 100644 --- a/presto-native-execution/CMakeLists.txt +++ b/presto-native-execution/CMakeLists.txt @@ -63,6 +63,8 @@ option(PRESTO_ENABLE_TESTING "Enable tests" ON) option(PRESTO_ENABLE_JWT "Enable JWT (JSON Web Token) authentication" OFF) +option(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR "Enable Arrow Flight connector" OFF) + # Set all Velox options below add_compile_definitions(FOLLY_HAVE_INT128_T=1) diff --git a/presto-native-execution/README.md b/presto-native-execution/README.md index cccebfcfb8d03..1976be406c2e0 100644 --- a/presto-native-execution/README.md +++ b/presto-native-execution/README.md @@ -115,6 +115,15 @@ follow these steps: * For development, use `make debug` to build a non-optimized debug version. * Use `make unittest` to build and run tests. +#### Arrow Flight Connector +To enable Arrow Flight connector support, add to the extra cmake flags: +`EXTRA_CMAKE_FLAGS = -DPRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR=ON` + +The Arrow Flight connector requires the Arrow Flight library. 
You can install this dependency +by running the following script from the `presto/presto-native-execution` directory: + +`./scripts/setup-adapters.sh arrow_flight` + ### Makefile Targets A reminder of the available Makefile targets can be obtained using `make help` ``` diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index c06e00edf834c..1d6abc257f876 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(types) add_subdirectory(http) add_subdirectory(common) add_subdirectory(thrift) +add_subdirectory(connectors) add_library( presto_server_lib @@ -29,7 +30,6 @@ add_library( QueryContextManager.cpp ServerOperation.cpp SignalHandler.cpp - SystemConnector.cpp SessionProperties.cpp TaskManager.cpp TaskResource.cpp @@ -48,6 +48,7 @@ target_link_libraries( presto_common presto_exception presto_function_metadata + presto_connectors presto_http presto_operators presto_velox_conversion diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 94364d0b1c1f6..10f3a2d401645 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -21,11 +21,12 @@ #include "presto_cpp/main/PeriodicMemoryChecker.h" #include "presto_cpp/main/PeriodicTaskManager.h" #include "presto_cpp/main/SignalHandler.h" -#include "presto_cpp/main/SystemConnector.h" #include "presto_cpp/main/TaskResource.h" #include "presto_cpp/main/common/ConfigReader.h" #include "presto_cpp/main/common/Counters.h" #include "presto_cpp/main/common/Utils.h" +#include "presto_cpp/main/connectors/Registration.h" +#include "presto_cpp/main/connectors/SystemConnector.h" #include "presto_cpp/main/http/HttpConstants.h" #include "presto_cpp/main/http/filters/AccessLogFilter.h" #include 
"presto_cpp/main/http/filters/HttpEndpointLatencyFilter.h" @@ -48,13 +49,11 @@ #include "velox/common/memory/MmapAllocator.h" #include "velox/common/memory/SharedArbitrator.h" #include "velox/connectors/Connector.h" -#include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveDataSink.h" #include "velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h" -#include "velox/connectors/tpch/TpchConnector.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" #include "velox/dwio/dwrf/RegisterDwrfWriter.h" #include "velox/dwio/orc/reader/OrcReader.h" @@ -88,7 +87,6 @@ constexpr char const* kHttps = "https"; constexpr char const* kTaskUriFormat = "{}://{}:{}"; // protocol, address and port constexpr char const* kConnectorName = "connector.name"; -constexpr char const* kHiveHadoop2ConnectorName = "hive-hadoop2"; protocol::NodeState convertNodeState(presto::NodeState nodeState) { switch (nodeState) { @@ -255,38 +253,14 @@ void PrestoServer::run() { registerMemoryArbitrators(); registerShuffleInterfaceFactories(); registerCustomOperators(); - registerConnectorFactories(); - - // Register Velox connector factory for iceberg. - // The iceberg catalog is handled by the hive connector factory. - velox::connector::registerConnectorFactory( - std::make_shared( - "iceberg")); - - registerPrestoToVeloxConnector( - std::make_unique("hive")); - registerPrestoToVeloxConnector( - std::make_unique("hive-hadoop2")); - registerPrestoToVeloxConnector( - std::make_unique("iceberg")); - registerPrestoToVeloxConnector( - std::make_unique("tpch")); - // Presto server uses system catalog or system schema in other catalogs - // in different places in the code. All these resolve to the SystemConnector. 
- // Depending on where the operator or column is used, different prefixes can - // be used in the naming. So the protocol class is mapped - // to all the different prefixes for System tables/columns. - registerPrestoToVeloxConnector( - std::make_unique("$system")); - registerPrestoToVeloxConnector( - std::make_unique("system")); - registerPrestoToVeloxConnector( - std::make_unique("$system@system")); + + // Register Presto connector factories and connectors + registerConnectors(); initializeVeloxMemory(); initializeThreadPools(); - auto catalogNames = registerConnectors(fs::path(configDirectoryPath_)); + auto catalogNames = registerVeloxConnectors(fs::path(configDirectoryPath_)); const bool bindToNodeInternalAddressOnly = systemConfig->httpServerBindToNodeInternalAddressOnlyEnabled(); @@ -1179,25 +1153,7 @@ PrestoServer::getAdditionalHttpServerFilters() { return filters; } -void PrestoServer::registerConnectorFactories() { - // These checks for connector factories can be removed after we remove the - // registrations from the Velox library. 
- if (!velox::connector::hasConnectorFactory( - velox::connector::hive::HiveConnectorFactory::kHiveConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); - velox::connector::registerConnectorFactory( - std::make_shared( - kHiveHadoop2ConnectorName)); - } - if (!velox::connector::hasConnectorFactory( - velox::connector::tpch::TpchConnectorFactory::kTpchConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); - } -} - -std::vector PrestoServer::registerConnectors( +std::vector PrestoServer::registerVeloxConnectors( const fs::path& configDirectoryPath) { static const std::string kPropertiesExtension = ".properties"; diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.h b/presto-native-execution/presto_cpp/main/PrestoServer.h index cf346777d5b4f..d8c1a75628acd 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.h +++ b/presto-native-execution/presto_cpp/main/PrestoServer.h @@ -135,7 +135,7 @@ class PrestoServer { virtual std::vector> getAdditionalHttpServerFilters(); - virtual std::vector registerConnectors( + virtual std::vector registerVeloxConnectors( const fs::path& configDirectoryPath); /// Invoked to register the required dwio data sinks which are used by @@ -146,8 +146,6 @@ class PrestoServer { virtual void unregisterFileReadersAndWriters(); - virtual void registerConnectorFactories(); - /// Invoked by presto shutdown procedure to unregister connectors. 
virtual void unregisterConnectors(); diff --git a/presto-native-execution/presto_cpp/main/QueryContextManager.cpp b/presto-native-execution/presto_cpp/main/QueryContextManager.cpp index c230ca84363c8..b784fce193349 100644 --- a/presto-native-execution/presto_cpp/main/QueryContextManager.cpp +++ b/presto-native-execution/presto_cpp/main/QueryContextManager.cpp @@ -68,10 +68,10 @@ void updateFromSystemConfigs( } std::unordered_map> -toConnectorConfigs(const protocol::SessionRepresentation& session) { +toConnectorConfigs(const protocol::TaskUpdateRequest& taskUpdateRequest) { std::unordered_map> connectorConfigs; - for (const auto& entry : session.catalogProperties) { + for (const auto& entry : taskUpdateRequest.session.catalogProperties) { std::unordered_map connectorConfig; // remove native prefix from native connector session property names for (const auto& sessionProperty : entry.second) { @@ -80,6 +80,10 @@ toConnectorConfigs(const protocol::SessionRepresentation& session) { : sessionProperty.first; connectorConfig.emplace(veloxConfig, sessionProperty.second); } + connectorConfig.insert( + taskUpdateRequest.extraCredentials.begin(), + taskUpdateRequest.extraCredentials.end()); + connectorConfig.insert({"user", taskUpdateRequest.session.user}); connectorConfigs.insert( {entry.first, connectorConfig}); } @@ -120,9 +124,11 @@ QueryContextManager::QueryContextManager( std::shared_ptr QueryContextManager::findOrCreateQueryCtx( const protocol::TaskId& taskId, - const protocol::SessionRepresentation& session) { + const protocol::TaskUpdateRequest& taskUpdateRequest) { return findOrCreateQueryCtx( - taskId, toVeloxConfigs(session), toConnectorConfigs(session)); + taskId, + toVeloxConfigs(taskUpdateRequest.session), + toConnectorConfigs(taskUpdateRequest)); } std::shared_ptr QueryContextManager::findOrCreateQueryCtx( diff --git a/presto-native-execution/presto_cpp/main/QueryContextManager.h b/presto-native-execution/presto_cpp/main/QueryContextManager.h index 
16d97f551ee3d..f8b1a1836ce55 100644 --- a/presto-native-execution/presto_cpp/main/QueryContextManager.h +++ b/presto-native-execution/presto_cpp/main/QueryContextManager.h @@ -107,7 +107,7 @@ class QueryContextManager { std::shared_ptr findOrCreateQueryCtx( const protocol::TaskId& taskId, - const protocol::SessionRepresentation& session); + const protocol::TaskUpdateRequest& taskUpdateRequest); /// Calls the given functor for every present query context. void visitAllContexts(std::functionfindOrCreateQueryCtx( - taskId, updateRequest.session); + taskId, updateRequest); VeloxBatchQueryPlanConverter converter( shuffleName, @@ -340,7 +340,7 @@ proxygen::RequestHandler* TaskResource::createOrUpdateTask( queryCtx = taskManager_.getQueryContextManager()->findOrCreateQueryCtx( - taskId, updateRequest.session); + taskId, updateRequest); VeloxInteractiveQueryPlanConverter converter(queryCtx.get(), pool_); planFragment = converter.toVeloxQueryPlan( diff --git a/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt new file mode 100644 index 0000000000000..7c627b844a822 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt @@ -0,0 +1,20 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+add_library(presto_connectors Registration.cpp PrestoToVeloxConnector.cpp + SystemConnector.cpp) + +if(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) + add_subdirectory(arrow_flight) + target_link_libraries(presto_connectors presto_flight_connector) +endif() + +target_link_libraries(presto_connectors presto_types) diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp similarity index 99% rename from presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp rename to presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp index c525f88e35300..916bf605913a3 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp @@ -12,7 +12,9 @@ * limitations under the License. */ -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "presto_cpp/main/types/TypeParser.h" #include "presto_cpp/presto_protocol/connector/hive/HiveConnectorProtocol.h" #include "presto_cpp/presto_protocol/connector/iceberg/IcebergConnectorProtocol.h" #include "presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h" diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h similarity index 99% rename from presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h rename to presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h index eb33dfb54ca1d..eed81e4cc00f3 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h +++ b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h @@ -13,8 +13,6 @@ */ #pragma once -#include 
"PrestoToVeloxExpr.h" -#include "presto_cpp/main/types/TypeParser.h" #include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.h" #include "presto_cpp/presto_protocol/core/ConnectorProtocol.h" #include "velox/connectors/Connector.h" @@ -25,6 +23,8 @@ namespace facebook::presto { class PrestoToVeloxConnector; +class TypeParser; +class VeloxExprConverter; void registerPrestoToVeloxConnector( std::unique_ptr connector); diff --git a/presto-native-execution/presto_cpp/main/connectors/Registration.cpp b/presto-native-execution/presto_cpp/main/connectors/Registration.cpp new file mode 100644 index 0000000000000..d6f6555fb8a22 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/Registration.cpp @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/connectors/Registration.h" +#include "presto_cpp/main/connectors/SystemConnector.h" + +#ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h" +#endif + +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/tpch/TpchConnector.h" + +namespace facebook::presto { +namespace { + +constexpr char const* kHiveHadoop2ConnectorName = "hive-hadoop2"; +constexpr char const* kIcebergConnectorName = "iceberg"; + +void registerConnectorFactories() { + // These checks for connector factories can be removed after we remove the + // registrations from the Velox library. + if (!velox::connector::hasConnectorFactory( + velox::connector::hive::HiveConnectorFactory::kHiveConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared()); + velox::connector::registerConnectorFactory( + std::make_shared( + kHiveHadoop2ConnectorName)); + } + if (!velox::connector::hasConnectorFactory( + velox::connector::tpch::TpchConnectorFactory::kTpchConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared()); + } + + // Register Velox connector factory for iceberg. + // The iceberg catalog is handled by the hive connector factory. 
+ if (!velox::connector::hasConnectorFactory(kIcebergConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared( + kIcebergConnectorName)); + } + +#ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR + if (!velox::connector::hasConnectorFactory( + ArrowFlightConnectorFactory::kArrowFlightConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared()); + } +#endif +} +} // namespace + +void registerConnectors() { + registerConnectorFactories(); + + registerPrestoToVeloxConnector(std::make_unique( + velox::connector::hive::HiveConnectorFactory::kHiveConnectorName)); + registerPrestoToVeloxConnector( + std::make_unique(kHiveHadoop2ConnectorName)); + registerPrestoToVeloxConnector( + std::make_unique(kIcebergConnectorName)); + registerPrestoToVeloxConnector(std::make_unique( + velox::connector::tpch::TpchConnectorFactory::kTpchConnectorName)); + + // Presto server uses system catalog or system schema in other catalogs + // in different places in the code. All these resolve to the SystemConnector. + // Depending on where the operator or column is used, different prefixes can + // be used in the naming. So the protocol class is mapped + // to all the different prefixes for System tables/columns. 
+ registerPrestoToVeloxConnector( + std::make_unique("$system")); + registerPrestoToVeloxConnector( + std::make_unique("system")); + registerPrestoToVeloxConnector( + std::make_unique("$system@system")); + +#ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR + registerPrestoToVeloxConnector(std::make_unique( + ArrowFlightConnectorFactory::kArrowFlightConnectorName)); +#endif +} +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/Registration.h b/presto-native-execution/presto_cpp/main/connectors/Registration.h new file mode 100644 index 0000000000000..c95aefaacfcaa --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/Registration.h @@ -0,0 +1,20 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +namespace facebook::presto { + +void registerConnectors(); + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/SystemConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp similarity index 99% rename from presto-native-execution/presto_cpp/main/SystemConnector.cpp rename to presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp index 7622d203e8689..eb9fb48196e9d 100644 --- a/presto-native-execution/presto_cpp/main/SystemConnector.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "presto_cpp/main/SystemConnector.h" +#include "presto_cpp/main/connectors/SystemConnector.h" #include "presto_cpp/main/PrestoTask.h" #include "presto_cpp/main/TaskManager.h" diff --git a/presto-native-execution/presto_cpp/main/SystemConnector.h b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.h similarity index 98% rename from presto-native-execution/presto_cpp/main/SystemConnector.h rename to presto-native-execution/presto_cpp/main/connectors/SystemConnector.h index 52d9df595f736..e7ffd7f2519b6 100644 --- a/presto-native-execution/presto_cpp/main/SystemConnector.h +++ b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.h @@ -13,13 +13,12 @@ */ #pragma once -#include "presto_cpp/main/SystemSplit.h" -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/SystemSplit.h" #include "velox/connectors/Connector.h" namespace facebook::presto { - class TaskManager; class SystemColumnHandle : public velox::connector::ColumnHandle { diff --git a/presto-native-execution/presto_cpp/main/SystemSplit.h b/presto-native-execution/presto_cpp/main/connectors/SystemSplit.h similarity index 100% rename 
from presto-native-execution/presto_cpp/main/SystemSplit.h rename to presto-native-execution/presto_cpp/main/connectors/SystemSplit.h diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.cpp new file mode 100644 index 0000000000000..c0900c6c41de4 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.cpp @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h" + +namespace facebook::presto { + +std::string ArrowFlightConfig::authenticatorName() const { + return config_->get(kAuthenticatorName, "none"); +} + +std::optional ArrowFlightConfig::defaultServerHostname() const { + return static_cast>( + config_->get(kDefaultServerHost)); +} + +std::optional ArrowFlightConfig::defaultServerPort() const { + return static_cast>( + config_->get(kDefaultServerPort)); +} + +bool ArrowFlightConfig::defaultServerSslEnabled() const { + return config_->get(kDefaultServerSslEnabled, false); +} + +bool ArrowFlightConfig::serverVerify() const { + return config_->get(kServerVerify, true); +} + +std::optional ArrowFlightConfig::serverSslCertificate() const { + return static_cast>( + config_->get(kServerSslCertificate)); +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h new file mode 100644 index 0000000000000..77ad8e9379cf3 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "velox/common/config/Config.h" + +namespace facebook::presto { + +class ArrowFlightConfig { + public: + explicit ArrowFlightConfig( + std::shared_ptr config) + : config_{config} {} + + static constexpr const char* kAuthenticatorName = + "arrow-flight.authenticator.name"; + + static constexpr const char* kDefaultServerHost = "arrow-flight.server"; + + static constexpr const char* kDefaultServerPort = "arrow-flight.server.port"; + + static constexpr const char* kDefaultServerSslEnabled = + "arrow-flight.server-ssl-enabled"; + + static constexpr const char* kServerVerify = "arrow-flight.server.verify"; + + static constexpr const char* kServerSslCertificate = + "arrow-flight.server-ssl-certificate"; + + std::string authenticatorName() const; + + std::optional defaultServerHostname() const; + + std::optional defaultServerPort() const; + + bool defaultServerSslEnabled() const; + + bool serverVerify() const; + + std::optional serverSslCertificate() const; + + private: + const std::shared_ptr config_; +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp new file mode 100644 index 0000000000000..6aacbd339228f --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp @@ -0,0 +1,212 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include +#include +#include +#include +#include +#include "presto_cpp/main/common/ConfigReader.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "velox/vector/arrow/Bridge.h" + +using namespace facebook::velox::connector; + +namespace facebook::presto { +namespace { +std::shared_ptr getDefaultLocation( + const std::shared_ptr& config) { + auto defaultHost = config->defaultServerHostname(); + auto defaultPort = config->defaultServerPort(); + if (!defaultHost.has_value() || !defaultPort.has_value()) { + return nullptr; + } + + AFC_ASSIGN_OR_RAISE( + auto defaultLocation, + config->defaultServerSslEnabled() + ? arrow::flight::Location::ForGrpcTls( + defaultHost.value(), defaultPort.value()) + : arrow::flight::Location::ForGrpcTcp( + defaultHost.value(), defaultPort.value())); + + return std::make_shared(std::move(defaultLocation)); +} +} // namespace + +// Wrapper for CallOptions which does not add any member variables, +// but provides a write-only interface for adding call headers. 
+class CallOptionsAddHeaders : public arrow::flight::FlightCallOptions, + public arrow::flight::AddCallHeaders { + public: + void AddHeader(const std::string& key, const std::string& value) override { + headers.emplace_back(key, value); + } +}; + +std::shared_ptr +ArrowFlightConnector::initClientOpts( + const std::shared_ptr& config) { + auto clientOpts = std::make_shared(); + clientOpts->disable_server_verification = !config->serverVerify(); + + auto certPath = config->serverSslCertificate(); + if (certPath.has_value()) { + std::ifstream file(certPath.value()); + VELOX_CHECK(file.is_open(), "Could not open TLS certificate"); + std::string cert( + (std::istreambuf_iterator(file)), + (std::istreambuf_iterator())); + clientOpts->tls_root_certs = cert; + } + + return clientOpts; +} + +ArrowFlightDataSource::ArrowFlightDataSource( + const velox::RowTypePtr& outputType, + const std::unordered_map>& + columnHandles, + std::shared_ptr authenticator, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& flightConfig, + const std::shared_ptr& clientOpts) + : outputType_{outputType}, + authenticator_{std::move(authenticator)}, + connectorQueryCtx_{connectorQueryCtx}, + flightConfig_{flightConfig}, + clientOpts_{clientOpts}, + defaultLocation_(getDefaultLocation(flightConfig_)) { + VELOX_CHECK_NOT_NULL(clientOpts_, "FlightClientOptions is not initialized"); + + // columnMapping_ contains the real column names in the expected order. + // This is later used by projectOutputColumns to filter out unnecessary + // columns from the fetched chunk. 
+ columnMapping_.reserve(outputType_->size()); + + for (const auto& columnName : outputType_->names()) { + auto it = columnHandles.find(columnName); + VELOX_CHECK( + it != columnHandles.end(), + "missing columnHandle for column '{}'", + columnName); + + auto handle = + std::dynamic_pointer_cast(it->second); + VELOX_CHECK_NOT_NULL( + handle, + "handle for column '{}' is not an ArrowFlightColumnHandle", + columnName); + + columnMapping_.push_back(handle->name()); + } +} + +void ArrowFlightDataSource::addSplit(std::shared_ptr split) { + auto flightSplit = std::dynamic_pointer_cast(split); + VELOX_CHECK( + flightSplit, "ArrowFlightDataSource received wrong type of split"); + + auto flightEndpointStr = + folly::base64Decode(flightSplit->flightEndpointBytes_); + + arrow::flight::FlightEndpoint flightEndpoint; + AFC_ASSIGN_OR_RAISE( + flightEndpoint, + arrow::flight::FlightEndpoint::Deserialize(flightEndpointStr)); + + arrow::flight::Location loc; + if (!flightEndpoint.locations.empty()) { + loc = flightEndpoint.locations[0]; + } else { + VELOX_CHECK_NOT_NULL( + defaultLocation_, + "No location from Flight endpoint, default host or port is missing"); + loc = *defaultLocation_; + } + + AFC_ASSIGN_OR_RAISE( + auto client, arrow::flight::FlightClient::Connect(loc, *clientOpts_)); + + CallOptionsAddHeaders callOptsAddHeaders{}; + authenticator_->authenticateClient( + client, connectorQueryCtx_->sessionProperties(), callOptsAddHeaders); + + auto readerResult = client->DoGet(callOptsAddHeaders, flightEndpoint.ticket); + AFC_ASSIGN_OR_RAISE(currentReader_, readerResult); +} + +std::optional ArrowFlightDataSource::next( + uint64_t size, + velox::ContinueFuture& /* unused */) { + VELOX_CHECK_NOT_NULL(currentReader_, "Missing split, call addSplit() first"); + + AFC_ASSIGN_OR_RAISE(auto chunk, currentReader_->Next()); + + // Null data in the chunk indicates that the Flight stream is complete.
+ if (!chunk.data) { + currentReader_ = nullptr; + return nullptr; + } + + // Extract only required columns from the record batch as a velox RowVector. + auto output = projectOutputColumns(chunk.data); + + completedRows_ += output->size(); + completedBytes_ += output->inMemoryBytes(); + return output; +} + +velox::RowVectorPtr ArrowFlightDataSource::projectOutputColumns( + const std::shared_ptr& input) { + velox::memory::MemoryPool* pool = connectorQueryCtx_->memoryPool(); + std::vector children; + children.reserve(columnMapping_.size()); + + // Extract and convert desired columns in the correct order. + for (const auto& name : columnMapping_) { + auto column = input->GetColumnByName(name); + VELOX_CHECK_NOT_NULL(column, "column with name '{}' not found", name); + ArrowArray array; + ArrowSchema schema; + AFC_RAISE_NOT_OK(arrow::ExportArray(*column, &array, &schema)); + children.push_back(velox::importFromArrowAsOwner(schema, array, pool)); + } + + return std::make_shared( + pool, + outputType_, + velox::BufferPtr() /*nulls*/, + input->num_rows(), + std::move(children)); +} + +std::unique_ptr +ArrowFlightConnector::createDataSource( + const velox::RowTypePtr& outputType, + const std::shared_ptr& tableHandle, + const std::unordered_map< + std::string, + std::shared_ptr>& columnHandles, + velox::connector::ConnectorQueryCtx* connectorQueryCtx) { + return std::make_unique( + outputType, + columnHandles, + authenticator_, + connectorQueryCtx, + flightConfig_, + clientOpts_); +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h new file mode 100644 index 0000000000000..92d73a42420b4 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h @@ -0,0 +1,184 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use 
this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h" +#include "presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h" +#include "velox/connectors/Connector.h" + +namespace arrow { +class RecordBatch; +namespace flight { +class FlightClientOptions; +class FlightStreamReader; +class Location; +} // namespace flight +} // namespace arrow + +namespace facebook::presto { + +class ArrowFlightTableHandle : public velox::connector::ConnectorTableHandle { + public: + explicit ArrowFlightTableHandle(const std::string& connectorId) + : ConnectorTableHandle(connectorId) {} +}; + +struct ArrowFlightSplit : public velox::connector::ConnectorSplit { + /// @param connectorId + /// @param flightEndpointBytes Base64 Serialized `FlightEndpoint` + ArrowFlightSplit( + const std::string& connectorId, + const std::string& flightEndpointBytes) + : ConnectorSplit(connectorId), + flightEndpointBytes_(flightEndpointBytes) {} + + const std::string flightEndpointBytes_; +}; + +class ArrowFlightColumnHandle : public velox::connector::ColumnHandle { + public: + explicit ArrowFlightColumnHandle(const std::string& columnName) + : columnName_(columnName) {} + + /// Returns the column name stored in this handle. + /// Declared const so it can be called through const handles and const + /// references; it only reads columnName_. + const std::string& name() const { + return columnName_; + } + + private: + std::string columnName_; +}; + +class ArrowFlightDataSource : public velox::connector::DataSource { + public: + ArrowFlightDataSource( + const velox::RowTypePtr& outputType, + const std::unordered_map< + std::string, + std::shared_ptr>& columnHandles, +
std::shared_ptr authenticator, + const velox::connector::ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& flightConfig, + const std::shared_ptr& clientOpts); + + void addSplit( + std::shared_ptr split) override; + + std::optional next( + uint64_t size, + velox::ContinueFuture& /* unused */) override; + + void addDynamicFilter( + velox::column_index_t outputChannel, + const std::shared_ptr& filter) override { + VELOX_UNSUPPORTED("Arrow Flight connector doesn't support dynamic filters"); + } + + uint64_t getCompletedBytes() override { + return completedBytes_; + } + + uint64_t getCompletedRows() override { + return completedRows_; + } + + std::unordered_map runtimeStats() + override { + return {}; + } + + private: + /// Convert an Arrow record batch to Velox RowVector. + /// Process only those columns that are present in outputType_. + velox::RowVectorPtr projectOutputColumns( + const std::shared_ptr& input); + + velox::RowTypePtr outputType_; + std::vector columnMapping_; + std::unique_ptr currentReader_; + uint64_t completedRows_ = 0; + uint64_t completedBytes_ = 0; + std::shared_ptr authenticator_; + const velox::connector::ConnectorQueryCtx* const connectorQueryCtx_; + const std::shared_ptr flightConfig_; + const std::shared_ptr clientOpts_; + const std::shared_ptr defaultLocation_; +}; + +class ArrowFlightConnector : public velox::connector::Connector { + public: + explicit ArrowFlightConnector( + const std::string& id, + std::shared_ptr config, + const char* authenticatorName = nullptr) + : Connector(id), + flightConfig_(std::make_shared(config)), + clientOpts_(initClientOpts(flightConfig_)), + authenticator_(getAuthenticatorFactory( + authenticatorName + ? 
authenticatorName + : flightConfig_->authenticatorName()) + ->newAuthenticator(config)) {} + + std::unique_ptr createDataSource( + const velox::RowTypePtr& outputType, + const std::shared_ptr& + tableHandle, + const std::unordered_map< + std::string, + std::shared_ptr>& columnHandles, + velox::connector::ConnectorQueryCtx* connectorQueryCtx) override; + + std::unique_ptr createDataSink( + velox::RowTypePtr inputType, + std::shared_ptr + connectorInsertTableHandle, + velox::connector::ConnectorQueryCtx* connectorQueryCtx, + velox::connector::CommitStrategy commitStrategy) override { + VELOX_NYI("The arrow-flight connector does not support a DataSink"); + } + + private: + static std::shared_ptr initClientOpts( + const std::shared_ptr& config); + + const std::shared_ptr flightConfig_; + const std::shared_ptr clientOpts_; + const std::shared_ptr authenticator_; +}; + +class ArrowFlightConnectorFactory : public velox::connector::ConnectorFactory { + public: + static constexpr const char* kArrowFlightConnectorName = "arrow-flight"; + + ArrowFlightConnectorFactory() : ConnectorFactory(kArrowFlightConnectorName) {} + + explicit ArrowFlightConnectorFactory( + const char* name, + const char* authenticatorName = nullptr) + : ConnectorFactory(name), authenticatorName_(authenticatorName) {} + + std::shared_ptr newConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor = nullptr, + folly::Executor* cpuExecutor = nullptr) override { + return std::make_shared( + id, config, authenticatorName_); + } + + private: + const char* authenticatorName_{nullptr}; +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp new file mode 100644 index 0000000000000..1ac5ab838f1db --- /dev/null +++ 
b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h" +#include +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h" + +namespace facebook::presto { + +std::unique_ptr +ArrowPrestoToVeloxConnector::toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* const connectorSplit, + const protocol::SplitContext* /*splitContext*/) const { + auto arrowSplit = + dynamic_cast(connectorSplit); + VELOX_CHECK_NOT_NULL( + arrowSplit, "Unexpected split type {}", connectorSplit->_type); + return std::make_unique( + catalogId, arrowSplit->flightEndpointBytes); +} + +std::unique_ptr +ArrowPrestoToVeloxConnector::toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& /*typeParser*/) const { + auto arrowColumn = + dynamic_cast(column); + VELOX_CHECK_NOT_NULL( + arrowColumn, "Unexpected column handle type {}", column->_type); + return std::make_unique( + arrowColumn->columnName); +} + +std::unique_ptr +ArrowPrestoToVeloxConnector::toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& /*exprConverter*/, + const TypeParser& /*typeParser*/, + std::unordered_map< + std::string, + 
std::shared_ptr>& assignments) const { + return std::make_unique( + tableHandle.connectorId); +} + +std::unique_ptr +ArrowPrestoToVeloxConnector::createConnectorProtocol() const { + return std::make_unique(); +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h new file mode 100644 index 0000000000000..fa7ab67b9c0b7 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" + +namespace facebook::presto { + +class ArrowPrestoToVeloxConnector final : public PrestoToVeloxConnector { + public: + explicit ArrowPrestoToVeloxConnector(std::string connectorName) + : PrestoToVeloxConnector(std::move(connectorName)) {} + + std::unique_ptr toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* splitContext) const final; + + std::unique_ptr toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const final; + + std::unique_ptr toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser, + std::unordered_map< + std::string, + std::shared_ptr>& assignments) + const final; + + std::unique_ptr createConnectorProtocol() + const final; +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt new file mode 100644 index 0000000000000..c03b517c2518f --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt @@ -0,0 +1,34 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+find_package(Arrow REQUIRED CONFIG) +find_package(ArrowFlight REQUIRED) + +add_subdirectory(auth) + +add_library(presto_flight_connector_utils INTERFACE Macros.h) +target_link_libraries(presto_flight_connector_utils INTERFACE velox_exception) + +add_library( + presto_flight_connector OBJECT + ArrowFlightConnector.cpp ArrowPrestoToVeloxConnector.cpp + ArrowFlightConfig.cpp) + +target_compile_definitions(presto_flight_connector + PUBLIC PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) + +target_link_libraries( + presto_flight_connector velox_connector ArrowFlight::arrow_flight_shared + presto_flight_connector_utils presto_flight_connector_auth presto_types) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h new file mode 100644 index 0000000000000..5ab725e582cc6 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/base/Exceptions.h" + +// Macros for dealing with arrow::Status and arrow::Result objects +// and converting them to velox exceptions. + +/// Raise a Velox exception if status is not OK. +/// Counterpart of ARROW_RETURN_NOT_OK. 
+#define AFC_RAISE_NOT_OK(status) \ + do { \ + ::arrow::Status _afcStatus = ::arrow::internal::GenericToStatus(status); \ + VELOX_CHECK(_afcStatus.ok(), _afcStatus.message()); \ + } while (false) + +/// Implementation detail of AFC_ASSIGN_OR_RAISE. +/// Expands to several statements (declaration, check, assignment), so it must +/// not be used as the body of an unbraced if/else/loop. +/// Macro-local names avoid double underscores, which are reserved identifiers. +#define AFC_ASSIGN_OR_RAISE_IMPL(result_name, lhs, rexpr) \ + auto&& result_name = (rexpr); \ + VELOX_CHECK((result_name).ok(), (result_name).status().message()); \ + lhs = std::move(result_name).ValueUnsafe(); + +/// Raise a Velox exception if expr doesn't return an OK result, +/// else unwrap the value and assign it to `lhs`. +/// `std::move`s its right hand operand. +/// Counterpart of ARROW_ASSIGN_OR_RAISE. +#define AFC_ASSIGN_OR_RAISE(lhs, rexpr) \ + AFC_ASSIGN_OR_RAISE_IMPL( \ + ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), lhs, rexpr); + +/// Raise a Velox exception if rexpr doesn't return an OK result, +/// else unwrap the value and return it. +/// `std::move`s its right hand operand. +#define AFC_RETURN_OR_RAISE(rexpr) \ + do { \ + auto&& _afcResult = (rexpr); \ + VELOX_CHECK(_afcResult.ok(), _afcResult.status().message()); \ + return std::move(_afcResult).ValueUnsafe(); \ + } while (false) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp new file mode 100644 index 0000000000000..78b6bf611b52c --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h" +#include +#include "velox/common/base/Exceptions.h" + +namespace facebook::presto { +namespace { +auto& authenticatorFactories() { + static std::unordered_map> + factories; + return factories; +} +} // namespace + +bool registerAuthenticatorFactory( + std::shared_ptr factory) { + bool ok = authenticatorFactories().insert({factory->name(), factory}).second; + VELOX_CHECK( + ok, + "Flight AuthenticatorFactory with name {} is already registered", + factory->name()); + return true; +} + +std::shared_ptr getAuthenticatorFactory( + const std::string& name) { + auto it = authenticatorFactories().find(name); + VELOX_CHECK( + it != authenticatorFactories().end(), + "Flight AuthenticatorFactory with name {} not registered", + name); + return it->second; +} + +AFC_REGISTER_AUTH_FACTORY(std::make_shared()) + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h new file mode 100644 index 0000000000000..bf44f8c3603ab --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "velox/common/config/Config.h" + +namespace arrow::flight { +class AddCallHeaders; +class FlightClient; +} // namespace arrow::flight + +namespace facebook::presto { + +/// Interface for authenticating a Flight client before requests are issued. +class Authenticator { + public: + /// Virtual so implementations are destroyed correctly through a base pointer. + virtual ~Authenticator() = default; + + /// @brief Override this method to define implementation-specific + /// authentication. This could be through client->Authenticate, or + /// client->AuthenticateBasicToken, or any other custom strategy. + /// @param client the Flight client which is to be authenticated + /// @param sessionProperties connector session properties + /// @param headerWriter write-only object used to set authentication headers + virtual void authenticateClient( + std::unique_ptr& client, + const velox::config::ConfigBase* sessionProperties, + arrow::flight::AddCallHeaders& headerWriter) = 0; +}; + +/// Named factory that creates Authenticator instances from a connector config. +class AuthenticatorFactory { + public: + explicit AuthenticatorFactory(std::string_view name) : name_{name} {} + + /// Virtual so implementations are destroyed correctly through a base pointer. + virtual ~AuthenticatorFactory() = default; + + /// Returns the name this factory was constructed with. + const std::string& name() const { + return name_; + } + + virtual std::shared_ptr newAuthenticator( + std::shared_ptr config) = 0; + + private: + std::string name_; +}; + +bool registerAuthenticatorFactory( + std::shared_ptr factory); + +std::shared_ptr getAuthenticatorFactory( + const std::string& name); + +#define AFC_REGISTER_AUTH_FACTORY(factory) \ + namespace { \ + static bool FB_ANONYMOUS_VARIABLE(g_ConnectorFactory) = \ + ::facebook::presto::registerAuthenticatorFactory((factory)); \ + } + +class NoOpAuthenticator : public Authenticator { + public: + void authenticateClient( + std::unique_ptr& client, + const velox::config::ConfigBase* sessionProperties, + arrow::flight::AddCallHeaders& headerWriter) override {} +}; + +class NoOpAuthenticatorFactory : public AuthenticatorFactory { + public: + static constexpr const std::string_view kNoOpAuthenticatorName{"none"}; + + NoOpAuthenticatorFactory() : AuthenticatorFactory{kNoOpAuthenticatorName} {} + + explicit NoOpAuthenticatorFactory(std::string_view name) + : AuthenticatorFactory{name} {} + +
std::shared_ptr newAuthenticator( + std::shared_ptr config) override { + return std::make_shared(); + } +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt new file mode 100644 index 0000000000000..1e7eba3154a0e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt @@ -0,0 +1,15 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library(presto_flight_connector_auth Authenticator.cpp) + +target_link_libraries(presto_flight_connector_auth + presto_flight_connector_utils velox_exception) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConfigTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConfigTest.cpp new file mode 100644 index 0000000000000..eb946f1fcae76 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConfigTest.cpp @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h" +#include + +using namespace facebook::velox; +using namespace facebook::presto; + +TEST(ArrowFlightConfigTest, defaultConfig) { + auto rawConfig = std::make_shared( + std::move(std::unordered_map{})); + auto config = ArrowFlightConfig(rawConfig); + ASSERT_EQ(config.authenticatorName(), "none"); + ASSERT_EQ(config.defaultServerHostname(), std::nullopt); + ASSERT_EQ(config.defaultServerPort(), std::nullopt); + ASSERT_EQ(config.defaultServerSslEnabled(), false); + ASSERT_EQ(config.serverVerify(), true); + ASSERT_EQ(config.serverSslCertificate(), std::nullopt); +} + +TEST(ArrowFlightConfigTest, overrideConfig) { + std::unordered_map configMap = { + {ArrowFlightConfig::kAuthenticatorName, "my-authenticator"}, + {ArrowFlightConfig::kDefaultServerHost, "my-server-host"}, + {ArrowFlightConfig::kDefaultServerPort, "9000"}, + {ArrowFlightConfig::kDefaultServerSslEnabled, "true"}, + {ArrowFlightConfig::kServerVerify, "false"}, + {ArrowFlightConfig::kServerSslCertificate, "my-cert.crt"}}; + auto config = ArrowFlightConfig( + std::make_shared(std::move(configMap))); + ASSERT_EQ(config.authenticatorName(), "my-authenticator"); + ASSERT_EQ(config.defaultServerHostname(), "my-server-host"); + ASSERT_EQ(config.defaultServerPort(), 9000); + ASSERT_EQ(config.defaultServerSslEnabled(), true); + ASSERT_EQ(config.serverVerify(), false); + ASSERT_EQ(config.serverSslCertificate(), "my-cert.crt"); +} diff --git 
a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorAuthTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorAuthTest.cpp new file mode 100644 index 0000000000000..e698ecd02948f --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorAuthTest.cpp @@ -0,0 +1,235 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" + +using namespace arrow; +using namespace facebook::velox; +using namespace facebook::velox::exec::test; + +namespace facebook::presto::test { + +class TestServerMiddlewareFactory : public flight::ServerMiddlewareFactory { + public: + static constexpr const char* kAuthHeader = "authorization"; + static constexpr const char* kAuthToken = "Bearer 1234"; + static constexpr const char* kAuthTokenUnauthorized = "Bearer 2112"; + + arrow::Status 
StartCall( + const flight::CallInfo& info, + const flight::ServerCallContext& context, + std::shared_ptr* middleware) override { + auto iter = context.incoming_headers().find(kAuthHeader); + + if (iter == context.incoming_headers().end()) { + return flight::MakeFlightError( + flight::FlightStatusCode::Unauthenticated, + "Authorization token not provided"); + } else { + std::lock_guard l(mutex_); + checkedTokens_.emplace_back(iter->second); + } + + if (kAuthToken != iter->second) { + return flight::MakeFlightError( + flight::FlightStatusCode::Unauthorized, + "Authorization token is invalid"); + } + + return arrow::Status::OK(); + } + + bool isTokenChecked(const std::string& authToken) { + { + std::lock_guard l(mutex_); + return std::find( + checkedTokens_.begin(), checkedTokens_.end(), authToken) != + checkedTokens_.end(); + } + } + + private: + std::string validToken_; + std::vector checkedTokens_; + std::mutex mutex_; +}; + +class TestAuthenticator : public Authenticator { + public: + explicit TestAuthenticator(const std::string& authToken) + : authToken_(authToken) {} + + void authenticateClient( + std::unique_ptr& client, + const velox::config::ConfigBase* sessionProperties, + arrow::flight::AddCallHeaders& headerWriter) override { + if (!authToken_.empty()) { + headerWriter.AddHeader( + TestServerMiddlewareFactory::kAuthHeader, authToken_); + } + } + + private: + std::string authToken_; +}; + +class TestAuthenticatorFactory : public AuthenticatorFactory { + public: + TestAuthenticatorFactory( + const std::string& name, + const std::string& authToken) + : AuthenticatorFactory(name), + testAuthenticator_{std::make_shared(authToken)} {} + + std::shared_ptr newAuthenticator( + std::shared_ptr config) override { + return testAuthenticator_; + } + + private: + std::shared_ptr testAuthenticator_; +}; + +namespace { +constexpr const char* kAuthFactoryName = "testing-auth-valid"; +constexpr const char* kAuthFactoryUnauthorizedName = + "testing-auth-unauthorized"; 
+constexpr const char* kAuthFactoryNoTokenName = "testing-auth-no-token"; + +bool registerTestAuthFactories() { + static bool once = [] { + auto authFactory = std::make_shared( + kAuthFactoryName, TestServerMiddlewareFactory::kAuthToken); + registerAuthenticatorFactory(authFactory); + auto authFactoryUnauthorized = std::make_shared( + kAuthFactoryUnauthorizedName, + TestServerMiddlewareFactory::kAuthTokenUnauthorized); + registerAuthenticatorFactory(authFactoryUnauthorized); + auto authFactoryNoToken = + std::make_shared(kAuthFactoryNoTokenName, ""); + registerAuthenticatorFactory(authFactoryNoToken); + return true; + }(); + return once; +} +} // namespace + +class ArrowFlightConnectorAuthTestBase : public ArrowFlightConnectorTestBase { + public: + explicit ArrowFlightConnectorAuthTestBase(const std::string& authFactoryName) + : ArrowFlightConnectorTestBase( + std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kAuthenticatorName, authFactoryName}})), + testMiddlewareFactory_( + std::make_shared()) {} + + void SetUp() override { + registerTestAuthFactories(); + ArrowFlightConnectorTestBase::SetUp(); + } + + void setFlightServerOptions( + flight::FlightServerOptions* serverOptions) override { + serverOptions->middleware.push_back( + {"bearer-auth-server", testMiddlewareFactory_}); + } + + core::PlanNodePtr addSampleDataAndRunQuery() { + updateTable( + "sample-data", + makeArrowTable( + {"id", "value"}, + {makeNumericArray( + {1, 12, 2, std::numeric_limits::max()}), + makeNumericArray( + {41, 42, 43, std::numeric_limits::min()})})); + + return ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW({"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + } + + protected: + std::shared_ptr testMiddlewareFactory_; +}; + +class ArrowFlightConnectorAuthTest : public ArrowFlightConnectorAuthTestBase { + public: + ArrowFlightConnectorAuthTest() + : ArrowFlightConnectorAuthTestBase(kAuthFactoryName) {} +}; + +TEST_F(ArrowFlightConnectorAuthTest, 
customAuthenticator) { + core::PlanNodePtr plan = addSampleDataAndRunQuery(); + + auto idVec = + makeFlatVector({1, 12, 2, std::numeric_limits::max()}); + auto valueVec = makeFlatVector( + {41, 42, 43, std::numeric_limits::min()}); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, valueVec})); + + ASSERT_TRUE(testMiddlewareFactory_->isTokenChecked( + TestServerMiddlewareFactory::kAuthToken)); +} + +class ArrowFlightConnectorUnauthorizedTest + : public ArrowFlightConnectorAuthTestBase { + public: + ArrowFlightConnectorUnauthorizedTest() + : ArrowFlightConnectorAuthTestBase(kAuthFactoryUnauthorizedName) {} +}; + +TEST_F(ArrowFlightConnectorUnauthorizedTest, unauthorizedToken) { + core::PlanNodePtr plan = addSampleDataAndRunQuery(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertEmptyResults(), + "Unauthorized"); +} + +class ArrowFlightConnectorUnauthenticatedTest + : public ArrowFlightConnectorAuthTestBase { + public: + ArrowFlightConnectorUnauthenticatedTest() + : ArrowFlightConnectorAuthTestBase(kAuthFactoryNoTokenName) {} +}; + +TEST_F(ArrowFlightConnectorUnauthenticatedTest, unauthenticatedNoToken) { + core::PlanNodePtr plan = addSampleDataAndRunQuery(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertEmptyResults(), + "Unauthenticated"); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp new file mode 100644 index 0000000000000..257497caf224d --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp @@ -0,0 +1,506 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" + +using namespace arrow; +using namespace facebook::velox; +using namespace facebook::velox::exec::test; + +namespace facebook::presto::test { + +class ArrowFlightConnectorDataTypeTest : public ArrowFlightConnectorTestBase {}; + +TEST_F(ArrowFlightConnectorDataTypeTest, booleanType) { + updateTable( + "sample-data", + makeArrowTable( + {"bool_col"}, {makeBooleanArray({true, false, true, false})})); + + auto boolVec = makeFlatVector({true, false, true, false}); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({"bool_col"}, {velox::BOOLEAN()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({boolVec})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, integerTypes) { + std::vector tinyData = { + -128, 0, 127, std::numeric_limits::max()}; + std::vector smallData = { + -32768, 0, 32767, std::numeric_limits::max()}; + std::vector intData = { + -2147483648, 0, 2147483647, 
std::numeric_limits::max()}; + std::vector bigData = { + -3435678987654321234LL, + 0, + 4527897896541234567LL, + std::numeric_limits::max()}; + + updateTable( + "sample-data", + makeArrowTable( + {"tinyint_col", "smallint_col", "integer_col", "bigint_col"}, + {makeNumericArray(tinyData), + makeNumericArray(smallData), + makeNumericArray(intData), + makeNumericArray(bigData)})); + + auto tinyintVec = makeFlatVector(tinyData); + auto smallintVec = makeFlatVector(smallData); + auto integerVec = makeFlatVector(intData); + auto bigintVec = makeFlatVector(bigData); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"tinyint_col", "smallint_col", "integer_col", "bigint_col"}, + {velox::TINYINT(), + velox::SMALLINT(), + velox::INTEGER(), + velox::BIGINT()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults( + makeRowVector({tinyintVec, smallintVec, integerVec, bigintVec})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, realType) { + std::vector realData = { + std::numeric_limits::min(), + 0.0f, + 3.14f, + std::numeric_limits::max()}; + std::vector doubleData = { + std::numeric_limits::min(), + 0.0, + 3.14159, + std::numeric_limits::max()}; + + updateTable( + "sample-data", + makeArrowTable( + {"real_col", "double_col"}, + {makeNumericArray(realData), + makeNumericArray(doubleData)})); + + auto realVec = makeFlatVector(realData); + auto doubleVec = makeFlatVector(doubleData); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"real_col", "double_col"}, {velox::REAL(), velox::DOUBLE()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({realVec, doubleVec})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, varcharType) { + std::vector data = { + "Hello", + "World", + "India", + "Hello World", // Inlined + "Hello World!", // Not inlined + "HelloWorldIndia", + 
"HelloWorldIndia!!!"}; + + updateTable( + "sample-data", makeArrowTable({"varchar_col"}, {makeStringArray(data)})); + + auto vec = + makeFlatVector(makeStringViewVector(data)); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({"varchar_col"}, {velox::VARCHAR()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({vec})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, varcharSpecialChars) { + std::vector data = { + "Hello", + "WORLD", + "hi there world", + " there there", + "hello \"world\"", + "hello_#@,$|%/^~?{}+-", + "city.id@address:number/date|day$a-b$10_bucket", + "café", + "abc\\x00def", + "日本語"}; + + updateTable( + "sample-data", makeArrowTable({"varchar_col"}, {makeStringArray(data)})); + + auto vec = + makeFlatVector(makeStringViewVector(data)); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({"varchar_col"}, {velox::VARCHAR()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({vec})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, varbinaryType) { + std::vector data = {"abc", "defghijk", "lmnopqrstuvwxyz"}; + + updateTable( + "sample-data", + makeArrowTable({"varbinary_col"}, {makeBinaryArray(data)})); + + auto vec = makeFlatVector( + makeStringViewVector(data), velox::VARBINARY()); + + core::PlanNodePtr plan; + plan = + ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({"varbinary_col"}, {velox::VARBINARY()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({vec})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, timestampType) { + auto timestampValues = + std::vector{1622538000, 1622541600, 1622545200}; + + updateTable( + "sample-data", + makeArrowTable( + {"timestampsec_col", "timestampmilli_col", "timestampmicro_col"}, + {makeTimestampArray(timestampValues, 
arrow::TimeUnit::SECOND), + makeTimestampArray(timestampValues, arrow::TimeUnit::MILLI), + makeTimestampArray(timestampValues, arrow::TimeUnit::MICRO)})); + + std::vector veloxTimestampSec; + for (const auto& ts : timestampValues) { + veloxTimestampSec.emplace_back(ts, 0); // Whole seconds; the nanoseconds part is 0 + } + + auto timestampSecCol = + makeFlatVector(veloxTimestampSec); + + std::vector veloxTimestampMilli; + for (const auto& ts : timestampValues) { + veloxTimestampMilli.emplace_back( + ts / 1000, (ts % 1000) * 1000000); // Convert to seconds and nanoseconds + } + + auto timestampMilliCol = + makeFlatVector(veloxTimestampMilli); + + std::vector veloxTimestampMicro; + for (const auto& ts : timestampValues) { + veloxTimestampMicro.emplace_back( + ts / 1000000, + (ts % 1000000) * 1000); // Convert to seconds and nanoseconds + } + + auto timestampMicroCol = + makeFlatVector(veloxTimestampMicro); + + core::PlanNodePtr plan; + plan = + ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"timestampsec_col", "timestampmilli_col", "timestampmicro_col"}, + {velox::TIMESTAMP(), velox::TIMESTAMP(), velox::TIMESTAMP()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector( + {timestampSecCol, timestampMilliCol, timestampMicroCol})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, dateDayType) { + std::vector datesDay = {18748, 18749, 18750}; // Days since epoch + std::vector datesMilli = { + 1622538000000, 1622541600000, 1622545200000}; // Milliseconds since epoch + + updateTable( + "sample-data", + makeArrowTable( + {"daydate_col", "daymilli_col"}, + {makeNumericArray(datesDay), + makeNumericArray(datesMilli)})); + + auto dateVec = makeFlatVector(datesDay); + auto milliVec = makeFlatVector(datesMilli); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({"daydate_col"}, {velox::DATE()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + 
.assertResults(makeRowVector({dateVec})); + + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({"daymilli_col"}, {velox::DATE()})) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({milliVec})), + "Unable to convert 'tdm' ArrowSchema format type to Velox"); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, decimalType) { + std::vector decimalValuesBigInt = { + 123456789012345678, + -123456789012345678, + std::numeric_limits::max()}; + std::vector> decimalArrayVec; + decimalArrayVec.push_back(makeDecimalArray(decimalValuesBigInt, 18, 2)); + updateTable( + "sample-data", makeArrowTable({"decimal_col_bigint"}, decimalArrayVec)); + auto decimalVecBigInt = makeFlatVector(decimalValuesBigInt); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"decimal_col_bigint"}, + {velox::DECIMAL(18, 2)})) // precision can't be 0 and < scale + .planNode(); + + // Execute the query and assert the results + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({decimalVecBigInt})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, arrayType) { + std::vector>>> data = { + std::nullopt, {{1, 2}}, {{3, std::nullopt, 4}}, {{5, 6, 7, 8, 9}}}; + + auto intBuilder = std::make_shared>(); + arrow::ListBuilder listBuilder(arrow::default_memory_pool(), intBuilder); + + for (const auto& array : data) { + AFC_RAISE_NOT_OK(listBuilder.Append(array.has_value())); + if (array.has_value()) { + const auto& elements = array.value(); + for (const auto& element : elements) { + if (element.has_value()) { + AFC_RAISE_NOT_OK(intBuilder->Append(element.value())); + } else { + AFC_RAISE_NOT_OK(intBuilder->AppendNull()); + } + } + } + } + AFC_ASSIGN_OR_RAISE(auto listArray, listBuilder.Finish()); + + updateTable("sample-data", makeArrowTable({"int_array_col"}, {listArray})); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() 
+ .flightTableScan(velox::ROW( + {"int_array_col"}, {velox::ARRAY(velox::INTEGER())})) + .planNode(); + + auto expectedData = makeNullableArrayVector(data); + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({expectedData})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, mapType) { + auto data = std::vector< + std::optional>>>>({ + {{{0, 100}, {1, 101}, {2, 102}}}, + {{{std::numeric_limits::max(), + std::numeric_limits::max()}, + {std::numeric_limits::min(), + std::numeric_limits::min()}}}, + {{}}, + {{{42, std::nullopt}}}, + std::nullopt, + {{{3, -300}, + {4, 400}, + {5, -500}, + {6, 600}, + {7, -700}, + {8, 800}, + {9, -900}}}, + }); + + auto keyBuilder = std::make_shared>(); + auto itemBuilder = + std::make_shared>(); + arrow::MapBuilder mapBuilder( + arrow::default_memory_pool(), keyBuilder, itemBuilder); + + for (const auto& mapElements : data) { + if (mapElements.has_value()) { + AFC_RAISE_NOT_OK(mapBuilder.Append()); + const auto& pairs = mapElements.value(); + for (const auto& key_item : pairs) { + AFC_RAISE_NOT_OK(keyBuilder->Append(key_item.first)); + if (key_item.second.has_value()) { + AFC_RAISE_NOT_OK(itemBuilder->Append(key_item.second.value())); + } else { + AFC_RAISE_NOT_OK(itemBuilder->AppendNull()); + } + } + } else { + AFC_RAISE_NOT_OK(mapBuilder.AppendNull()); + } + } + AFC_ASSIGN_OR_RAISE(auto mapArray, mapBuilder.Finish()); + + updateTable("sample-data", makeArrowTable({"map_col"}, {mapArray})); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"map_col"}, {velox::MAP(velox::INTEGER(), velox::BIGINT())})) + .planNode(); + + auto expectedData = makeNullableMapVector(data); + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({expectedData})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, rowType) { + std::vector intData = {0, 1, 2, 3, 4}; + std::vector varcharData = {"a", "bb", "ccc", "dddd", "eeeee"}; + 
std::vector doubleData = {0.0, 1.1, 2.2, 3.3, 4.4}; + + auto recordBatch = makeRecordBatch( + {"int_col", "varchar_col", "double_col"}, + {makeNumericArray(intData), + makeStringArray(varcharData), + makeNumericArray(doubleData)}); + AFC_ASSIGN_OR_RAISE(auto structArray, recordBatch->ToStructArray()); + + updateTable("sample-data", makeArrowTable({"row_col"}, {structArray})); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"row_col"}, + {velox::ROW( + {"int_col", "varchar_col", "double_col"}, + {velox::INTEGER(), velox::VARCHAR(), velox::DOUBLE()})})) + .planNode(); + + auto expectedData = makeRowVector( + {makeFlatVector(intData), + makeFlatVector(varcharData), + makeFlatVector(doubleData)}); + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({expectedData})); +} + +TEST_F(ArrowFlightConnectorDataTypeTest, allTypes) { + auto timestampValues = + std::vector{1622550000, 1622553600, 1622557200}; + + auto sampleTable = makeArrowTable( + {"id", + "daydate_col", + "timestamp_col", + "varchar_col", + "real_col", + "int_col", + "bool_col"}, + {makeNumericArray({1, 2, 3}), + makeNumericArray({18748, 18749, 18750}), + makeTimestampArray(timestampValues, arrow::TimeUnit::SECOND), + makeStringArray({"apple", "banana", "cherry"}), + makeNumericArray({3.14, 2.718, 1.618}), + makeNumericArray( + {-32768, 32767, std::numeric_limits::max()}), + makeBooleanArray({true, false, true})}); + + updateTable("gen-data", sampleTable); + + auto dateVec = makeFlatVector({18748, 18749, 18750}); + + std::vector veloxTimestampSec; + for (const auto& ts : timestampValues) { + veloxTimestampSec.emplace_back(ts, 0); // Whole seconds; the nanoseconds part is 0 + } + auto timestampSecVec = + makeFlatVector(veloxTimestampSec); + + auto stringVec = makeFlatVector( + {facebook::velox::StringView("apple"), + facebook::velox::StringView("banana"), + facebook::velox::StringView("cherry")}); + auto realVec = 
makeFlatVector({3.14, 2.718, 1.618}); + auto intVec = makeFlatVector( + {-32768, 32767, std::numeric_limits::max()}); + auto boolVec = makeFlatVector({true, false, true}); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"daydate_col", + "timestamp_col", + "varchar_col", + "real_col", + "int_col", + "bool_col"}, + {velox::DATE(), + velox::TIMESTAMP(), + velox::VARCHAR(), + velox::DOUBLE(), + velox::INTEGER(), + velox::BOOLEAN()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"gen-data"})) + .assertResults(makeRowVector( + {dateVec, timestampSecVec, stringVec, realVec, intVec, boolVec})); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp new file mode 100644 index 0000000000000..acb41c5087ea1 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp @@ -0,0 +1,284 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include +#include +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PortUtil.h" + +using namespace arrow; +using namespace facebook::velox; +using namespace facebook::velox::exec::test; + +namespace facebook::presto::test { + +class ArrowFlightConnectorTest : public ArrowFlightConnectorTestBase {}; + +TEST_F(ArrowFlightConnectorTest, invalidSplit) { + auto plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({{"id", velox::BIGINT()}})) + .planNode(); + + VELOX_ASSERT_THROW( + velox::exec::test::AssertQueryBuilder(plan) + .splits(makeSplits({"unknown"})) + .copyResults(pool()), + "table does not exist"); +} + +TEST_F(ArrowFlightConnectorTest, dataSourceCreation) { + // missing columnHandle test + auto plan = + ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW({"id", "value"}, {velox::BIGINT(), velox::INTEGER()}), + {{"id", std::make_shared("id")}}, + false /*createDefaultColumnHandles*/) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .copyResults(pool()), + "missing columnHandle for column 'value'"); +} + +TEST_F(ArrowFlightConnectorTest, dataSource) { + std::vector idData = {1, 12, 2, std::numeric_limits::max()}; + std::vector valueData = { + 41, 42, 43, std::numeric_limits::min()}; + std::vector unsignedData = { + 41, 42, 43, std::numeric_limits::min()}; + + updateTable( + "sample-data", + makeArrowTable( + {"id", "value", "unsigned"}, + 
{makeNumericArray(idData), + makeNumericArray(valueData), + // note that velox doesn't support unsigned types + // connector should still be able to query such tables + // as long as this specific column isn't requested. + makeNumericArray(unsignedData)})); + + auto idColumn = std::make_shared("id"); + auto idVec = makeFlatVector(idData); + + auto valueColumn = std::make_shared("value"); + auto valueVec = makeFlatVector(valueData); + + core::PlanNodePtr plan; + + // direct test + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, valueVec})); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"}, std::vector{})) + .assertResults(makeRowVector({idVec, valueVec})), + "default host or port is missing"); + + // column alias test + plan = + ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW({"ducks", "id"}, {velox::BIGINT(), velox::BIGINT()}), + {{"ducks", idColumn}}) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, idVec})); + + // invalid columnHandle test + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"ducks", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .copyResults(pool()), + "column with name 'ducks' not found"); +} + +TEST_F(ArrowFlightConnectorTest, multipleBatches) { + vector_size_t numValues = 100; + int64_t batchSize = 30; + std::string letters = "abcdefghijklmnopqrstuvwxyz"; + std::vector intData(numValues); + std::vector varcharData(numValues); + std::vector doubleData(numValues); + + for (vector_size_t i = 0; i < numValues; i++) { + intData[i] = i; + size_t pos = i % letters.size(); + size_t len = std::min(i % 5, letters.size() 
- pos); + varcharData[i] = letters.substr(pos, len); + doubleData[i] = i * 1.1; + } + + updateTable( + "sample-data", + makeArrowTable( + {"int_col", "varchar_col", "double_col"}, + {makeNumericArray(intData), + makeStringArray(varcharData), + makeNumericArray(doubleData)})); + setBatchSize(batchSize); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"int_col", "varchar_col", "double_col"}, + {velox::INTEGER(), velox::VARCHAR(), velox::DOUBLE()})) + .planNode(); + + auto intVec = makeFlatVector(intData); + auto varcharVec = makeFlatVector(varcharData); + auto doubleVec = makeFlatVector(doubleData); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({intVec, varcharVec, doubleVec})); +} + +TEST_F(ArrowFlightConnectorTest, multipleSplits) { + vector_size_t numValues = 100; + int64_t batchSize = 30; + std::string letters = "abcdefghijklmnopqrstuvwxyz"; + std::vector intData(numValues); + std::vector varcharData(numValues); + std::vector doubleData(numValues); + + for (vector_size_t i = 0; i < numValues; i++) { + intData[i] = i; + size_t pos = i % letters.size(); + size_t len = std::min(i % 5, letters.size() - pos); + varcharData[i] = letters.substr(pos, len); + doubleData[i] = i * 1.1; + } + + updateTable( + "sample-data-1", + makeArrowTable( + {"int_col", "varchar_col", "double_col"}, + {makeNumericArray(intData), + makeStringArray(varcharData), + makeNumericArray(doubleData)})); + updateTable( + "sample-data-2", + makeArrowTable( + {"int_col", "varchar_col", "double_col"}, + {makeNumericArray(intData), + makeStringArray(varcharData), + makeNumericArray(doubleData)})); + setBatchSize(batchSize); + + core::PlanNodePtr plan; + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"int_col", "varchar_col", "double_col"}, + {velox::INTEGER(), velox::VARCHAR(), velox::DOUBLE()})) + .planNode(); + + std::vector intDataAll(intData); + 
intDataAll.insert(intDataAll.end(), intData.begin(), intData.end()); // Two identical splits are scanned, so every column holds its data twice + std::vector varcharDataAll(varcharData); + varcharDataAll.insert( + varcharDataAll.end(), varcharData.begin(), varcharData.end()); + std::vector doubleDataAll(doubleData); + doubleDataAll.insert( + doubleDataAll.end(), doubleData.begin(), doubleData.end()); + auto intVec = makeFlatVector(intDataAll); + auto varcharVec = makeFlatVector(varcharDataAll); + auto doubleVec = makeFlatVector(doubleDataAll); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data-1", "sample-data-2"})) + .assertResults(makeRowVector({intVec, varcharVec, doubleVec})); +} + +class ArrowFlightConnectorTestDefaultServer + : public ArrowFlightConnectorTestBase { + public: + ArrowFlightConnectorTestDefaultServer() + : ArrowFlightConnectorTestBase( + std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kDefaultServerHost, kConnectHost}, + {ArrowFlightConfig::kDefaultServerPort, + std::to_string(getFreePort())}})) {} +}; + +TEST_F(ArrowFlightConnectorTestDefaultServer, dataSource) { + std::vector idData = {1, 12, 2, std::numeric_limits::max()}; + std::vector valueData = { + 41, 42, 43, std::numeric_limits::min()}; + + updateTable( + "sample-data", + makeArrowTable( + {"id", "value"}, + {makeNumericArray(idData), + makeNumericArray(valueData)})); + + auto idColumn = std::make_shared("id"); + auto idVec = makeFlatVector(idData); + + auto valueColumn = std::make_shared("value"); + auto valueVec = makeFlatVector(valueData); + + core::PlanNodePtr plan; + + // direct test + plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW( + {"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, valueVec})); + + AssertQueryBuilder(plan) + .splits(makeSplits( + {"sample-data"}, + std::vector{})) // Using default connector + .assertResults(makeRowVector({idVec, valueVec})); +} + +} // namespace 
facebook::presto::test + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::Init init{&argc, &argv, false}; + return RUN_ALL_TESTS(); +} diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp new file mode 100644 index 0000000000000..4453183a39412 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" + +using namespace arrow; +using namespace facebook::velox; +using namespace facebook::velox::exec::test; + +namespace facebook::presto::test { + +class ArrowFlightConnectorTlsTestBase : public ArrowFlightConnectorTestBase { + protected: + explicit ArrowFlightConnectorTlsTestBase( + std::shared_ptr config) + : ArrowFlightConnectorTestBase(std::move(config)) {} + + void setFlightServerOptions( + flight::FlightServerOptions* serverOptions) override { + flight::CertKeyPair tlsCertificate{ + .pem_cert = readFile("./data/tls_certs/server.crt"), + .pem_key = readFile("./data/tls_certs/server.key")}; + serverOptions->tls_certificates.push_back(tlsCertificate); + } + + void executeTest( + bool isPositiveTest = true, + const std::string& expectedError = "") { + std::vector idData = { + 1, 12, 2, std::numeric_limits::max()}; + + updateTable( + "sample-data", + makeArrowTable({"id"}, {makeNumericArray(idData)})); + + auto idVec = makeFlatVector(idData); + + auto plan = ArrowFlightPlanBuilder() + .flightTableScan(velox::ROW({"id"}, {velox::BIGINT()})) + .planNode(); + + auto runQuery = [&]() { + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec})); + }; + + if (isPositiveTest) { + runQuery(); + } else { + VELOX_ASSERT_THROW(runQuery(), expectedError); + } + } +}; + +class ArrowFlightConnectorTlsTest : public ArrowFlightConnectorTlsTestBase { + protected: + explicit ArrowFlightConnectorTlsTest() + : 
ArrowFlightConnectorTlsTestBase( + std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kDefaultServerSslEnabled, "true"}, + {ArrowFlightConfig::kServerVerify, "true"}, + {ArrowFlightConfig::kServerSslCertificate, + "./data/tls_certs/ca.crt"}})) {} +}; + +TEST_F(ArrowFlightConnectorTlsTest, tlsEnabled) { + executeTest(); +} + +class ArrowFlightTlsNoCertValidationTest + : public ArrowFlightConnectorTlsTestBase { + protected: + explicit ArrowFlightTlsNoCertValidationTest() + : ArrowFlightConnectorTlsTestBase( + std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kDefaultServerSslEnabled, "true"}, + {ArrowFlightConfig::kServerVerify, "false"}})) {} +}; + +TEST_F(ArrowFlightTlsNoCertValidationTest, tlsNoCertValidation) { + executeTest(); +} + +class ArrowFlightTlsNoCertTest : public ArrowFlightConnectorTlsTestBase { + protected: + ArrowFlightTlsNoCertTest() + : ArrowFlightConnectorTlsTestBase( + std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kDefaultServerSslEnabled, "true"}, + {ArrowFlightConfig::kServerVerify, "true"}})) {} +}; + +TEST_F(ArrowFlightTlsNoCertTest, tlsNoCert) { + executeTest(false, "handshake failed"); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt new file mode 100644 index 0000000000000..9af596a913973 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt @@ -0,0 +1,45 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_subdirectory(utils) + +add_executable(presto_flight_connector_infra_test + TestingArrowFlightServerTest.cpp) + +add_test(presto_flight_connector_infra_test presto_flight_connector_infra_test) + +target_link_libraries( + presto_flight_connector_infra_test presto_protocol + presto_flight_connector_test_lib GTest::gtest GTest::gtest_main ${GLOG}) + +add_executable( + presto_flight_connector_test + ArrowFlightConnectorTest.cpp ArrowFlightConnectorAuthTest.cpp + ArrowFlightConnectorTlsTest.cpp ArrowFlightConnectorDataTypeTest.cpp + ArrowFlightConfigTest.cpp) + +set(DATA_DIR "${CMAKE_CURRENT_SOURCE_DIR}/data/tls_certs") # TLS fixtures the tests read via ./data/tls_certs + +add_custom_target( + copy_flight_test_data ALL + COMMAND ${CMAKE_COMMAND} -E copy_directory ${DATA_DIR} + $<TARGET_FILE_DIR:presto_flight_connector_test>/data/tls_certs) + +add_test(presto_flight_connector_test presto_flight_connector_test) + +target_link_libraries( + presto_flight_connector_test + velox_exec_test_lib + presto_flight_connector + gtest + gtest_main + presto_flight_connector_test_lib + presto_protocol) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestingArrowFlightServerTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestingArrowFlightServerTest.cpp new file mode 100644 index 0000000000000..9a7a95201a4da --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestingArrowFlightServerTest.cpp @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h" +#include +#include +#include +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" + +using namespace arrow; + +namespace facebook::presto::test { + +class TestingArrowFlightServerTest : public testing::Test { + public: + static void SetUpTestSuite() { + server = std::make_unique(); + ASSERT_OK_AND_ASSIGN( + auto loc, flight::Location::ForGrpcTcp("127.0.0.1", 0)); + ASSERT_OK(server->Init(flight::FlightServerOptions(loc))); + } + + static void TearDownTestSuite() { + ASSERT_OK(server->Shutdown()); + } + + static void updateTable( + const std::string& name, + const std::shared_ptr& table) { + server->updateTable(name, table); + } + + void SetUp() override { + ASSERT_OK_AND_ASSIGN( + auto loc, flight::Location::ForGrpcTcp("localhost", server->port())); + ASSERT_OK_AND_ASSIGN(client_, flight::FlightClient::Connect(loc)); + } + + std::unique_ptr client_; + static std::unique_ptr server; +}; + +std::unique_ptr TestingArrowFlightServerTest::server; + +TEST_F(TestingArrowFlightServerTest, basicClientConnection) { + auto sampleTable = makeArrowTable( + {"id", "value"}, + {makeNumericArray({1, 2}), + makeNumericArray({41, 42})}); + updateTable("sample-data", sampleTable); + + ASSERT_RAISES(KeyError, client_->DoGet(flight::Ticket{"empty"})); + + auto emptyTable = makeArrowTable({}, {}); + updateTable("empty", emptyTable); + + ASSERT_RAISES(KeyError, client_->DoGet(flight::Ticket{"non-existent-table"})); + + 
ASSERT_OK_AND_ASSIGN(auto reader, client_->DoGet(flight::Ticket{"empty"})); + ASSERT_OK_AND_ASSIGN(auto actual, reader->ToTable()); + EXPECT_TRUE(actual->Equals(*emptyTable)); + + ASSERT_OK_AND_ASSIGN(reader, client_->DoGet(flight::Ticket{"sample-data"})); + ASSERT_OK_AND_ASSIGN(actual, reader->ToTable()); + EXPECT_TRUE(actual->Equals(*sampleTable)); + + server->removeTable("sample-data"); + ASSERT_RAISES(KeyError, client_->DoGet(flight::Ticket{"sample-data"})); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md new file mode 100644 index 0000000000000..3a5f2e5786c67 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md @@ -0,0 +1,7 @@ +### Placeholder TLS Certificates for Arrow Flight Connector Unit Testing +The `tls_certs` directory contains placeholder TLS certificates generated for unit testing the Arrow Flight Connector with TLS enabled. These certificates are not intended for production use and should only be used in the context of unit tests. + +### Generating TLS Certificates +To create the TLS certificates and keys inside the `tls_certs` folder, run the following command from this directory (the script writes to `./tls_certs` relative to the current working directory): + +`./generate_tls_certs.sh` diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh new file mode 100755 index 0000000000000..718f313c70a75 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Set directory for certificates and keys. +CERT_DIR="./tls_certs" +mkdir -p $CERT_DIR + +# Dummy values for the certificates. 
+COUNTRY="US" +STATE="State" +LOCALITY="City" +ORGANIZATION="MyOrg" +ORG_UNIT="MyUnit" +COMMON_NAME="MyCA" +SERVER_CN="server.mydomain.com" + +# Step 1: Generate CA private key and self-signed certificate. +openssl genpkey -algorithm RSA -out $CERT_DIR/ca.key +openssl req -key $CERT_DIR/ca.key -new -x509 -out $CERT_DIR/ca.crt -days 365000 \ + -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$COMMON_NAME" + +# Step 2: Generate server private key. +openssl genpkey -algorithm RSA -out $CERT_DIR/server.key + +# Step 3: Generate server certificate signing request (CSR). +openssl req -new -key $CERT_DIR/server.key -out $CERT_DIR/server.csr \ + -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$SERVER_CN" \ + -addext "subjectAltName=DNS:$COMMON_NAME,DNS:localhost" + +# Step 4: Sign server CSR with the CA certificate to generate the server certificate. +openssl x509 -req -in $CERT_DIR/server.csr -CA $CERT_DIR/ca.crt -CAkey $CERT_DIR/ca.key \ + -CAcreateserial -out $CERT_DIR/server.crt -days 365000 \ + -extfile <(printf "subjectAltName=DNS:$COMMON_NAME,DNS:localhost") + +# Step 5: Output the generated files. +echo "Certificate Authority (CA) certificate: $CERT_DIR/ca.crt" +echo "Server certificate: $CERT_DIR/server.crt" +echo "Server private key: $CERT_DIR/server.key" + +# Step 6: Remove unused files.
+rm -rf $CERT_DIR/server.csr $CERT_DIR/ca.srl $CERT_DIR/ca.key diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt new file mode 100644 index 0000000000000..6740e89c54e17 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDmzCCAoOgAwIBAgIUf+rP48iL39yGlAfFQTIp5bmM4uQwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MQ0wCwYDVQQDDARNeUNB +MCAXDTI0MTIwMzExMDQxMVoYDzMwMjQwNDA1MTEwNDExWjBcMQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxDjAMBgNVBAoMBU15T3Jn +MQ8wDQYDVQQLDAZNeVVuaXQxDTALBgNVBAMMBE15Q0EwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCliiXIcSmxXAAq2k/XjcZniDgEDCxWKZGiV8JBiJwY +MMBJtqcVzWfiDpO2u6d1dfGb6utlRW+1dnwupzURCMmZff4bqlPx4ZejRXDrWzKz +08WSpDVZwC2H5XOllwK36Cn4gvPRe3YWVcdDGHy7GL+zsJENvawJj0BH952MU4bk +sV52zEkN291bfN9sSYfT1NCJuLPM0Qsf97DeQ+wHXEw+t4XVMF3FQbciQp0y6CnA +wfFFN14WDiWxukP1I3kuDYYA6h/WJCQMp5rU2NCB9nIQrulYRxFaepMYENLxgAyj +gFaoRh2Kt2k7XKv6WOa6CmYm2dZERPlbA+oNAHkaHw6lAgMBAAGjUzBRMB0GA1Ud +DgQWBBSN+3vRlXGjs6c+rN94qgEnkPLl3DAfBgNVHSMEGDAWgBSN+3vRlXGjs6c+ +rN94qgEnkPLl3DAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAb +L40Oe2b/2xdUSyjqXJceVxaeA291fCpnu1C1JliP0hNI3fu9jjZhXHQoYub/4mod +8lriEDIcOCCiUfmi404akpqQHuBmOHaKEOtaaQkezjPsYnUra+O2ssqUo2zto5bK +gR0LGsb+4AO0bDvq+QVI6kEQqAAIf6qC+kpg/jV4iKJ1J6Qw4R3QppYBm6SQcfvI +hfUfDSO6SNfy0f/ZVCavbJIP9zG/BfAD9DEERocw03PiN5bm4IXJ3HH8rxyuBfJ5 +Eg/fPP5TlZ2H7Kqb3VgVBGWJtNXWmJphHyraBJTEuxgXWvl6AaW0P/3dsJi3rfdD +zDIT7AmENLCom8Gl0bgM +-----END CERTIFICATE----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt new file mode 
100644 index 0000000000000..92c91f2d613b0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIUUhmhZP94nIowrg2EarzfEBp6W1EwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MQ0wCwYDVQQDDARNeUNB +MCAXDTI0MTIwMzExMDQxMVoYDzMwMjQwNDA1MTEwNDExWjBrMQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxDjAMBgNVBAoMBU15T3Jn +MQ8wDQYDVQQLDAZNeVVuaXQxHDAaBgNVBAMME3NlcnZlci5teWRvbWFpbi5jb20w +ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDSxC4zCC4GFZbX+fdFgWbL +sj4PortyOM7mzRjNaQ3M0FTSEy5xET9C2qFlBCJ7AL7DlbSLmSckYY/FkdfMqNN2 ++NZ0Dy2d6bZN+ly5N/QBVnyS/5HVC3MXa6Y2BmFXiBnczWfGBwj+uVHlKOUWUyNi +EyUkhuPwtYXkFmJoqBxJSPC6cxX6NzMujnwCF18dUf0Vra44osu4moaovmg3c9jM +cBtmafFs9F54FoAEuLotjISVEa7VY6th5RxXJHpgas+0R5EBddGYKbTRiUYjht7r +pS+An0ey02oOjEWdqLnQSg/SUGKuRXULyE5l1A0HfNQtvepUQotb9ull1F7OrbfB +AgMBAAGjXjBcMBoGA1UdEQQTMBGCBE15Q0GCCWxvY2FsaG9zdDAdBgNVHQ4EFgQU +vnCLWjre4jqkKzC24psCPh1oIQwwHwYDVR0jBBgwFoAUjft70ZVxo7OnPqzfeKoB +J5Dy5dwwDQYJKoZIhvcNAQELBQADggEBAJCiJgtTw/7b/g3QnJfM4teQkFS440Ii +weqQJMoP6als8Fc3opPKv9eC5w0wqaLlIdwJjzGM5PmCAtGVafo22TbqhZyQdzQu +TUKv1DaVF0JBVAGVxTSDIK9r5Ww4mDAQnQENLC6soS3AvYDEi+8667YLoNNdhRCX +q2D5v76UN45idiShppxOw53whsvpHv+wyqcdse7DhgM9boCbx51Uvv3l/AEToyaj +S1xeIkBwNpSYU0ax2Lr1j2yoKbzAa3MHy8Php+T5CGji02+HwwlvlPDLtw8q5gHw +BLSwlAHgclPxUTWNNoCqjfX8Bi083+QDCLm0rgQ45xljNDbFAF1Y5hA= +-----END CERTIFICATE----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key new file mode 100644 index 0000000000000..2cdf5750a4753 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- 
+MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDSxC4zCC4GFZbX ++fdFgWbLsj4PortyOM7mzRjNaQ3M0FTSEy5xET9C2qFlBCJ7AL7DlbSLmSckYY/F +kdfMqNN2+NZ0Dy2d6bZN+ly5N/QBVnyS/5HVC3MXa6Y2BmFXiBnczWfGBwj+uVHl +KOUWUyNiEyUkhuPwtYXkFmJoqBxJSPC6cxX6NzMujnwCF18dUf0Vra44osu4moao +vmg3c9jMcBtmafFs9F54FoAEuLotjISVEa7VY6th5RxXJHpgas+0R5EBddGYKbTR +iUYjht7rpS+An0ey02oOjEWdqLnQSg/SUGKuRXULyE5l1A0HfNQtvepUQotb9ull +1F7OrbfBAgMBAAECggEAAxbZuuESGGAMMm9HLGhKHgbHU8gnv2Phdbrka+SYBYg5 +UYzTHLh3FwEsjd4VnaweJ7CN1WDb1NvWmTum/DCebJ1HKqtjKLAZfk8q2TLGmXdL +pzWOdQ8MX1fKP2sIlcl0kFbNCE8vprjneDyBLtqOK36eiAh/fl6BQ12QAMLjyv/L +OwXSY4ESs/RzxRzFgdT98cDZFL7y0FVIjJo/Q5lfW9UwwSfw8tOLNXKTYwPHqIfJ +NjfWD7IqztQlnanyRXv5dScp80i8p9qgH0i8YfVBHZDeOmHGLcltilLRZ0dQ/X0g +Lrr0aIO3iLhmTIkJRzUnGeyvDjxcPINvRSBBwXy04QKBgQDpFJa/EwSsWj8586oh +xgm0Z3q+FiEeCe7aLLPcXAS2EDvix5ibJDT2y1Aa/kXq25S53npa/7Ov6TJs5H4g +eyshDtR1wVhz+rIggREiX/sagkhwnNsssUZFv5t9PdnaFXpVnH49m5Qc8HO3owtN +t8EGSRcAQ4o/fLWLs51qd38cIQKBgQDnfd8YPyDQ03xDC/3+Qrypyc/xhGnCuj7w +ZeA5iEyTnnNxL0a0B6PWcSk2BZReMNQKgYtipnsOQKtwHMttxtXYs/VQpeB4KoWE +zEwW0fV3MMsXN+nVJlEZnVaTbmYXknjeZrh/rNjsY96yxw8NtvAuYSpnqtr3N2nd +iMQ3G/QnoQKBgGMi+bdNvIgeXpQkmrGAzTHpbaCaQv3G1cwAhYPts6dIomAj6znZ +nZl3ApxomI57VPf1s+8uoVvqASOl0Cu6l66Y4y8uzJOQBuGiZApN7rzouy0C2opY +4H3cMKOFgjqrNfxh8qP7n3TrpRxvgehNhxFIVzsqfwvf3EwOWp8lMnBhAoGAZ25E +Ge9K2ENGCCb5i3uCFFLJiF3ja1AQAxVhxBL0NBjd97pp2tJ3D79r7Gk9y4ABndgX +0TIVVV7ruqIC+r+WmMZ/W1NiIg7NrXIipSeWh3TTqUIgRk5iehFkt2biUrHtM2Gu +Gc2+9pAA1tw+C6CrW+2qJrueLksiEAulsAHba0ECgYBIgIiY+Gx+XecEgCwAhWcn +GzNDAAlA4IgBjHpUtIByflzQDqlECKXjPbVBKfyq6eLt40upFmQCLsn+AkiQau8A +3cFAK9wJOAHv9KuWDrbHyhRE9CrJ6BqsY2goC3LiFCTgJy1TrRl6CDaFzHivONwF +LNPflYk5s376UWqxC+HtIA== +-----END PRIVATE KEY----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp new file mode 100644 index 0000000000000..1fecf9e31977a --- 
/dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h" +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "velox/exec/tests/utils/PortUtil.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec::test; + +namespace facebook::presto::test { + +void ArrowFlightConnectorTestBase::SetUp() { + OperatorTestBase::SetUp(); + + if (!velox::connector::hasConnectorFactory( + presto::ArrowFlightConnectorFactory::kArrowFlightConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared()); + } + velox::connector::registerConnector( + velox::connector::getConnectorFactory( + ArrowFlightConnectorFactory::kArrowFlightConnectorName) + ->newConnector(kFlightConnectorId, config_)); + + ArrowFlightConfig config(config_); + if (config.defaultServerPort().has_value()) { + port_ = config.defaultServerPort().value(); + } else { + port_ = getFreePort(); + } + AFC_ASSIGN_OR_RAISE( + auto serverLocation, + config.defaultServerSslEnabled() + ? 
arrow::flight::Location::ForGrpcTls(kBindHost, port_) + : arrow::flight::Location::ForGrpcTcp(kBindHost, port_)); + + arrow::flight::FlightServerOptions serverOptions(serverLocation); + server_ = std::make_unique(); + setFlightServerOptions(&serverOptions); + ASSERT_OK(server_->Init(serverOptions)); +} + +void ArrowFlightConnectorTestBase::TearDown() { + ASSERT_OK(server_->Shutdown()); + velox::connector::unregisterConnector(kFlightConnectorId); + OperatorTestBase::TearDown(); +} + +std::vector> +ArrowFlightConnectorTestBase::makeSplits( + const std::initializer_list& tickets) { + ArrowFlightConfig config(config_); + AFC_ASSIGN_OR_RAISE( + auto loc, + config.defaultServerSslEnabled() + ? arrow::flight::Location::ForGrpcTls(kConnectHost, port_) + : arrow::flight::Location::ForGrpcTcp(kConnectHost, port_)); + return makeSplits(tickets, {loc}); +} + +std::vector> +ArrowFlightConnectorTestBase::makeSplits( + const std::initializer_list& tickets, + const std::vector& locations) { + std::vector> splits; + splits.reserve(tickets.size()); + for (auto& ticket : tickets) { + arrow::flight::FlightEndpoint flightEndpoint; + flightEndpoint.ticket.ticket = ticket; + flightEndpoint.locations = locations; + AFC_ASSIGN_OR_RAISE( + auto flightEndpointStr, flightEndpoint.SerializeToString()); + auto flightEndpointBytes = folly::base64Encode(flightEndpointStr); + splits.push_back(std::make_shared( + kFlightConnectorId, flightEndpointBytes)); + } + return splits; +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h new file mode 100644 index 0000000000000..9c928eb1d2ad0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h" +#include "velox/common/config/Config.h" +#include "velox/connectors/Connector.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" + +namespace facebook::presto::test { + +static const std::string kFlightConnectorId = "test-flight"; + +/// Creates and registers an Arrow Flight connector and +/// spawns a Flight server for testing. +/// Initially there is no data in the Flight server, +/// tests should call ArrowFlightConnectorTestBase::updateTables to populate it. +class ArrowFlightConnectorTestBase + : public velox::exec::test::OperatorTestBase { + public: + static constexpr const char* kBindHost = "127.0.0.1"; + static constexpr const char* kConnectHost = "localhost"; + + void SetUp() override; + + void TearDown() override; + + /// Create splits for this test flight server. + std::vector> makeSplits( + const std::initializer_list& tokens); + + /// Convenience function for creating splits with endpoint locations. + static std::vector> + makeSplits( + const std::initializer_list& tokens, + const std::vector& locations); + + /// Add (or update) a table in the test flight server. 
+ void updateTable( + const std::string& name, + const std::shared_ptr& table) { + server_->updateTable(name, table); + } + + void setBatchSize(int64_t batchSize) { + server_->setBatchSize(batchSize); + } + + virtual void setFlightServerOptions( + arrow::flight::FlightServerOptions* serverOptions) {} + + protected: + explicit ArrowFlightConnectorTestBase( + std::shared_ptr config) + : config_{std::move(config)} {} + + ArrowFlightConnectorTestBase() + : ArrowFlightConnectorTestBase( + std::make_shared( + std::unordered_map())) {} + + uint32_t port_; + std::unique_ptr server_; + std::shared_ptr config_; +}; + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.cpp new file mode 100644 index 0000000000000..5280f9c56832d --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.cpp @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" + +namespace facebook::presto::test { +namespace { +const std::string kFlightConnectorId = "test-flight"; +} + +velox::exec::test::PlanBuilder& ArrowFlightPlanBuilder::flightTableScan( + const velox::RowTypePtr& outputType, + std::unordered_map< + std::string, + std::shared_ptr> assignments, + bool createDefaultColumnHandles) { + if (createDefaultColumnHandles) { + for (const auto& name : outputType->names()) { + // Provide unaliased defaults for unmapped columns. + // `emplace` won't modify the map if the key already exists, + // so existing aliases are kept. + assignments.emplace( + name, std::make_shared(name)); + } + } + + return startTableScan() + .tableHandle(std::make_shared(kFlightConnectorId)) + .outputType(outputType) + .assignments(std::move(assignments)) + .endTableScan(); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h new file mode 100644 index 0000000000000..5eda2c60aac16 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "velox/exec/tests/utils/PlanBuilder.h" + +namespace facebook::presto::test { + +class ArrowFlightPlanBuilder : public velox::exec::test::PlanBuilder { + public: + /// @brief Add a table scan node to the Plan, using the Flight connector + /// @param outputType The output type of the table scan node + /// @param assignments mapping from the column aliases to real column handles + /// @param createDefaultColumnHandles If true, generate column handles + /// for the columns which don't have an entry in assignments + velox::exec::test::PlanBuilder& flightTableScan( + const velox::RowTypePtr& outputType, + std::unordered_map< + std::string, + std::shared_ptr> assignments = {}, + bool createDefaultColumnHandles = true); +}; + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt new file mode 100644 index 0000000000000..b6d2337a2d301 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt @@ -0,0 +1,19 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+add_library( + presto_flight_connector_test_lib + TestingArrowFlightServer.cpp ArrowFlightConnectorTestBase.cpp Utils.cpp + ArrowFlightPlanBuilder.cpp) + +target_link_libraries( + presto_flight_connector_test_lib arrow presto_flight_connector + velox_exception presto_flight_connector_utils velox_exec_test_lib) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.cpp new file mode 100644 index 0000000000000..c7db0df572fc9 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.cpp @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h" + +namespace facebook::presto::test { + +arrow::Status TestingArrowFlightServer::DoGet( + const arrow::flight::ServerCallContext& context, + const arrow::flight::Ticket& request, + std::unique_ptr* stream) { + auto it = tables_.find(request.ticket); + if (it == tables_.end()) { + return arrow::Status::KeyError( + "requested table does not exist: ", request.ticket); + } + auto& table = it->second; + auto reader = std::make_shared(table); + if (batchSize_.has_value()) { + reader->set_chunksize(batchSize_.value()); + } + *stream = std::make_unique(reader); + return arrow::Status::OK(); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h new file mode 100644 index 0000000000000..6781b43974108 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestingArrowFlightServer.h @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace facebook::presto::test { + +/// Test Flight server which supports DoGet operations. 
+/// Maintains a map of named Arrow tables; updateTable replaces by name. +/// +/// Normally, the tickets would be obtained by calling GetFlightInfo, +/// but since this is done by the coordinator this part is omitted. +/// Instead, the ticket is simply the name of the table to fetch. +class TestingArrowFlightServer : public arrow::flight::FlightServerBase { + public: + TestingArrowFlightServer() = default; + + void updateTable( + const std::string& name, + const std::shared_ptr& table) { + tables_.insert_or_assign(name, table); + } + + void removeTable(const std::string& name) { + tables_.erase(name); + } + + void setBatchSize(int64_t batchSize) { + batchSize_ = std::make_optional(batchSize); + } + + arrow::Status DoGet( + const arrow::flight::ServerCallContext& context, + const arrow::flight::Ticket& request, + std::unique_ptr* stream) override; + + private: + std::unordered_map> tables_; + std::optional batchSize_; +}; + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp new file mode 100644 index 0000000000000..48d209822fec0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp @@ -0,0 +1,105 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#include "Utils.h" +#include +#include +#include "velox/type/StringView.h" + +namespace facebook::presto::test { + +ArrowArrayPtr makeDecimalArray( + const std::vector& decimalValues, + int precision, + int scale) { + auto decimalType = arrow::decimal(precision, scale); + auto builder = + arrow::Decimal128Builder(decimalType, arrow::default_memory_pool()); + + for (const auto& value : decimalValues) { + arrow::Decimal128 dec(value); + AFC_RAISE_NOT_OK(builder.Append(dec)); + } + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +ArrowArrayPtr makeTimestampArray( + const std::vector& values, + arrow::TimeUnit::type timeUnit, + arrow::MemoryPool* memory_pool) { + arrow::TimestampBuilder builder(arrow::timestamp(timeUnit), memory_pool); + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::vector makeStringViewVector( + const std::vector& values) { + std::vector stringViewVector; + stringViewVector.reserve(values.size()); + for (const auto& value : values) { + stringViewVector.emplace_back(value); + } + return stringViewVector; +} + +ArrowArrayPtr makeStringArray(const std::vector& values) { + auto builder = arrow::StringBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +ArrowArrayPtr makeBinaryArray(const std::vector& values) { + auto builder = arrow::BinaryBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +ArrowArrayPtr makeBooleanArray(const std::vector& values) { + auto builder = arrow::BooleanBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeRecordBatch( + const std::vector& names, + const arrow::ArrayVector& arrays) { + VELOX_CHECK_EQ(names.size(), arrays.size()); + + auto numRows = (!arrays.empty()) ?
(arrays[0]->length()) : 0; + arrow::FieldVector fields{}; + for (size_t i = 0; i < arrays.size(); i++) { + VELOX_CHECK_EQ(arrays[i]->length(), numRows); + fields.push_back( + std::make_shared(names[i], arrays[i]->type())); + } + + auto schema = arrow::schema(fields); + return arrow::RecordBatch::Make(schema, numRows, arrays); +} + +std::shared_ptr makeArrowTable( + const std::vector& names, + const arrow::ArrayVector& arrays) { + AFC_RETURN_OR_RAISE( + arrow::Table::FromRecordBatches({makeRecordBatch(names, arrays)})); +} + +std::string readFile(const std::string& path) { + std::ifstream file(path); + VELOX_CHECK( + file.is_open(), "Could not open file \"{}\": {}", path, strerror(errno)); + return { + std::istreambuf_iterator(file), std::istreambuf_iterator()}; +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h new file mode 100644 index 0000000000000..fb27b03258e34 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#pragma once + +#include + +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox { +class StringView; +} // namespace facebook::velox + +namespace facebook::presto::test { + +using ArrowArrayPtr = std::shared_ptr; + +template +ArrowArrayPtr makeNumericArray(const std::vector& values) { + auto builder = arrow::NumericBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +ArrowArrayPtr makeDecimalArray( + const std::vector& decimalValues, + int precision, + int scale); + +ArrowArrayPtr makeTimestampArray( + const std::vector& values, + arrow::TimeUnit::type timeUnit, + arrow::MemoryPool* memory_pool = arrow::default_memory_pool()); + +std::vector makeStringViewVector( + const std::vector& values); + +ArrowArrayPtr makeStringArray(const std::vector& values); + +ArrowArrayPtr makeBinaryArray(const std::vector& values); + +ArrowArrayPtr makeBooleanArray(const std::vector& values); + +std::shared_ptr makeRecordBatch( + const std::vector& names, + const arrow::ArrayVector& arrays); + +std::shared_ptr makeArrowTable( + const std::vector& names, + const arrow::ArrayVector& arrays); + +std::string readFile(const std::string& path); + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/operators/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/operators/tests/CMakeLists.txt index 0547dcaeeafcb..19b0489f97174 100644 --- a/presto-native-execution/presto_cpp/main/operators/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/operators/tests/CMakeLists.txt @@ -20,6 +20,7 @@ add_test(presto_operators_test presto_operators_test) target_link_libraries( presto_operators_test + presto_connectors presto_operators_plan_builder presto_operators presto_protocol diff --git a/presto-native-execution/presto_cpp/main/tests/QueryContextManagerTest.cpp 
b/presto-native-execution/presto_cpp/main/tests/QueryContextManagerTest.cpp index 6c4616dfb7ec3..25637184986f3 100644 --- a/presto-native-execution/presto_cpp/main/tests/QueryContextManagerTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/QueryContextManagerTest.cpp @@ -61,8 +61,10 @@ TEST_F(QueryContextManagerTest, nativeSessionProperties) { {"native_expression_max_array_size_in_reduce", "99999"}, {"native_expression_max_compiled_regexes", "54321"}, }}; + protocol::TaskUpdateRequest updateRequest; + updateRequest.session = session; auto queryCtx = taskManager_->getQueryContextManager()->findOrCreateQueryCtx( - taskId, session); + taskId, updateRequest); EXPECT_EQ(queryCtx->queryConfig().maxSpillLevel(), 2); EXPECT_EQ(queryCtx->queryConfig().spillCompressionKind(), "NONE"); EXPECT_FALSE(queryCtx->queryConfig().joinSpillEnabled()); @@ -84,8 +86,10 @@ TEST_F(QueryContextManagerTest, nativeConnectorSessionProperties) { {"native_stats_based_filter_reorder_disabled", "true"}, {"orc_max_merge_distance", "512kB"}}; session.catalogProperties.emplace("hive", hiveSessions); + protocol::TaskUpdateRequest updateRequest; + updateRequest.session = session; auto queryCtx = taskManager_->getQueryContextManager()->findOrCreateQueryCtx( - taskId, session); + taskId, updateRequest); EXPECT_EQ( queryCtx->connectorSessionProperties().at("hive")->get( "orc_max_merge_distance"), @@ -102,8 +106,10 @@ TEST_F(QueryContextManagerTest, defaultSessionProperties) { protocol::TaskId taskId = "scan.0.0.1.0"; protocol::SessionRepresentation session{.systemProperties = {}}; + protocol::TaskUpdateRequest updateRequest; + updateRequest.session = session; auto queryCtx = taskManager_->getQueryContextManager()->findOrCreateQueryCtx( - taskId, session); + taskId, updateRequest); const auto& queryConfig = queryCtx->queryConfig(); EXPECT_EQ(queryConfig.maxSpillLevel(), defaultQC->maxSpillLevel()); EXPECT_EQ( @@ -124,9 +130,11 @@ TEST_F(QueryContextManagerTest, overridingSessionProperties) { 
const auto& systemConfig = SystemConfig::instance(); { protocol::SessionRepresentation session{.systemProperties = {}}; + protocol::TaskUpdateRequest updateRequest; + updateRequest.session = session; auto queryCtx = taskManager_->getQueryContextManager()->findOrCreateQueryCtx( - taskId, session); + taskId, updateRequest); // When session properties are not explicitly set, they should be set to // system config values. EXPECT_EQ( @@ -156,9 +164,11 @@ TEST_F(QueryContextManagerTest, overridingSessionProperties) { {"spill_enabled", "true"}, {"aggregation_spill_enabled", "false"}, {"join_spill_enabled", "true"}}}; + protocol::TaskUpdateRequest updateRequest; + updateRequest.session = session; auto queryCtx = taskManager_->getQueryContextManager()->findOrCreateQueryCtx( - taskId, session); + taskId, updateRequest); EXPECT_EQ( queryCtx->queryConfig().queryMaxMemoryPerNode(), 1UL * 1024 * 1024 * 1024); @@ -192,6 +202,8 @@ TEST_F(QueryContextManagerTest, overridingSessionProperties) { TEST_F(QueryContextManagerTest, duplicateQueryRootPoolName) { const protocol::TaskId fakeTaskId = "scan.0.0.1.0"; const protocol::SessionRepresentation fakeSession{.systemProperties = {}}; + protocol::TaskUpdateRequest fakeUpdateRequest; + fakeUpdateRequest.session = fakeSession; auto* queryCtxManager = taskManager_->getQueryContextManager(); struct { bool hasPendingReference; @@ -215,7 +227,7 @@ TEST_F(QueryContextManagerTest, duplicateQueryRootPoolName) { queryCtxManager->testingClearCache(); auto queryCtx = - queryCtxManager->findOrCreateQueryCtx(fakeTaskId, fakeSession); + queryCtxManager->findOrCreateQueryCtx(fakeTaskId, fakeUpdateRequest); const auto poolName = queryCtx->pool()->name(); ASSERT_THAT(poolName, testing::HasSubstr("scan_")); if (!testData.hasPendingReference) { @@ -225,7 +237,7 @@ TEST_F(QueryContextManagerTest, duplicateQueryRootPoolName) { queryCtxManager->testingClearCache(); } auto newQueryCtx = - queryCtxManager->findOrCreateQueryCtx(fakeTaskId, fakeSession); + 
queryCtxManager->findOrCreateQueryCtx(fakeTaskId, fakeUpdateRequest); const auto newPoolName = newQueryCtx->pool()->name(); ASSERT_THAT(newPoolName, testing::HasSubstr("scan_")); if (testData.expectedNewPoolName) { diff --git a/presto-native-execution/presto_cpp/main/tests/ServerOperationTest.cpp b/presto-native-execution/presto_cpp/main/tests/ServerOperationTest.cpp index e0c73446af4f0..68c820486d953 100644 --- a/presto-native-execution/presto_cpp/main/tests/ServerOperationTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/ServerOperationTest.cpp @@ -177,7 +177,7 @@ TEST_F(ServerOperationTest, taskEndpoint) { planFragment, true, taskManager->getQueryContextManager()->findOrCreateQueryCtx( - taskId, updateRequest.session), + taskId, updateRequest), 0); }; std::vector taskIds = {"task_0.0.0.0.0", "task_1.0.0.0.0"}; diff --git a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp index 599f8f6a8123d..04ee3d9fc0bde 100644 --- a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp @@ -18,9 +18,9 @@ #include "folly/experimental/EventCount.h" #include "presto_cpp/main/PrestoExchangeSource.h" #include "presto_cpp/main/TaskResource.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include "presto_cpp/main/tests/HttpServerWrapper.h" #include "presto_cpp/main/tests/MultableConfigs.h" -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" #include "velox/common/base/Fs.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" @@ -672,7 +672,7 @@ class TaskManagerTest : public exec::test::OperatorTestBase, bool summarize = true) { auto queryCtx = taskManager_->getQueryContextManager()->findOrCreateQueryCtx( - taskId, updateRequest.session); + taskId, updateRequest); return taskManager_->createOrUpdateTask( taskId, updateRequest, planFragment, 
summarize, std::move(queryCtx), 0); } diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt index 5841728512238..e22e2d8ddbd6b 100644 --- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt @@ -15,14 +15,14 @@ add_library(presto_type_converter OBJECT TypeParser.cpp) target_link_libraries(presto_type_converter velox_type_parser) add_library( - presto_types OBJECT - PrestoToVeloxQueryPlan.cpp PrestoToVeloxExpr.cpp VeloxPlanValidator.cpp - PrestoToVeloxSplit.cpp PrestoToVeloxConnector.cpp) + presto_types OBJECT PrestoToVeloxQueryPlan.cpp PrestoToVeloxExpr.cpp + VeloxPlanValidator.cpp PrestoToVeloxSplit.cpp) add_dependencies(presto_types presto_operators presto_type_converter velox_type velox_type_fbhive) -target_link_libraries(presto_types presto_type_converter velox_type_fbhive - velox_hive_partition_function velox_tpch_gen velox_functions_json) +target_link_libraries( + presto_types presto_type_converter velox_type_fbhive + velox_hive_partition_function velox_tpch_gen velox_functions_json) set_property(TARGET presto_types PROPERTY JOB_POOL_LINK presto_link_job_pool) diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp index 0bd1a5cf93876..9a3d2755b6435 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp @@ -14,7 +14,7 @@ // clang-format off #include "presto_cpp/main/common/Configs.h" -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h" #include #include "velox/core/QueryCtx.h" diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp 
b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp index 1d11be2e904fc..6ecda241f1478 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp @@ -12,7 +12,7 @@ * limitations under the License. */ #include "presto_cpp/main/types/PrestoToVeloxSplit.h" -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include "velox/exec/Exchange.h" using namespace facebook::velox; diff --git a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt index 28f73aff40b80..9286f3296bb09 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt @@ -20,6 +20,7 @@ target_link_libraries( presto_velox_split_test GTest::gtest GTest::gtest_main + presto_connectors presto_operators presto_protocol velox_dwio_common @@ -48,6 +49,7 @@ target_link_libraries( presto_expressions_test GTest::gtest GTest::gtest_main + presto_connectors $ $ $ @@ -86,6 +88,7 @@ add_test( target_link_libraries( presto_to_velox_connector_test + presto_connectors presto_protocol presto_operators presto_type_converter @@ -123,6 +126,7 @@ add_test( target_link_libraries( presto_to_velox_query_plan_test + presto_connectors presto_operators presto_protocol presto_type_converter diff --git a/presto-native-execution/presto_cpp/main/types/tests/PlanConverterTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/PlanConverterTest.cpp index 715780befa84a..3f9a402ba8809 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/PlanConverterTest.cpp +++ b/presto-native-execution/presto_cpp/main/types/tests/PlanConverterTest.cpp @@ -14,11 +14,11 @@ #include #include "presto_cpp/main/common/tests/test_json.h" +#include 
"presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include "presto_cpp/main/operators/LocalPersistentShuffle.h" #include "presto_cpp/main/operators/PartitionAndSerialize.h" #include "presto_cpp/main/operators/ShuffleRead.h" #include "presto_cpp/main/operators/ShuffleWrite.h" -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" #include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h" #include "presto_cpp/main/types/tests/TestUtils.h" #include "velox/connectors/hive/TableHandle.h" diff --git a/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxConnectorTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxConnectorTest.cpp index 932f48a611f73..a88b235686498 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxConnectorTest.cpp +++ b/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxConnectorTest.cpp @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include #include "velox/common/base/tests/GTestUtils.h" diff --git a/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxSplitTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxSplitTest.cpp index 5522684b262d2..9bcbf9b4f0542 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxSplitTest.cpp +++ b/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxSplitTest.cpp @@ -13,7 +13,7 @@ */ #include "presto_cpp/main/types/PrestoToVeloxSplit.h" #include -#include "presto_cpp/main/types/PrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" #include "velox/connectors/hive/HiveConnectorSplit.h" using namespace facebook::velox; diff --git a/presto-native-execution/presto_cpp/presto_protocol/Makefile b/presto-native-execution/presto_cpp/presto_protocol/Makefile index 3ee2b4e802b81..09b43df28b4f5 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/Makefile +++ b/presto-native-execution/presto_cpp/presto_protocol/Makefile @@ -45,14 +45,23 @@ presto_protocol-cpp: presto_protocol-json chevron -d connector/tpch/presto_protocol_tpch.json connector/tpch/presto_protocol-json-hpp.mustache >> connector/tpch/presto_protocol_tpch.h clang-format -style=file -i connector/tpch/presto_protocol_tpch.h connector/tpch/presto_protocol_tpch.cpp + # build arrow_flight connector related structs + echo "// DO NOT EDIT : This file is generated by chevron" > connector/arrow_flight/presto_protocol_arrow_flight.cpp + chevron -d connector/arrow_flight/presto_protocol_arrow_flight.json connector/arrow_flight/presto_protocol-json-cpp.mustache >> connector/arrow_flight/presto_protocol_arrow_flight.cpp + echo "// DO NOT EDIT : This file is generated by chevron" > connector/arrow_flight/presto_protocol_arrow_flight.h + chevron -d 
connector/arrow_flight/presto_protocol_arrow_flight.json connector/arrow_flight/presto_protocol-json-hpp.mustache >> connector/arrow_flight/presto_protocol_arrow_flight.h + clang-format -style=file -i connector/arrow_flight/presto_protocol_arrow_flight.h connector/arrow_flight/presto_protocol_arrow_flight.cpp + presto_protocol-json: ./java-to-struct-json.py --config core/presto_protocol_core.yml core/special/*.java core/special/*.inc -j | jq . > core/presto_protocol_core.json ./java-to-struct-json.py --config connector/hive/presto_protocol_hive.yml connector/hive/special/*.inc -j | jq . > connector/hive/presto_protocol_hive.json ./java-to-struct-json.py --config connector/iceberg/presto_protocol_iceberg.yml connector/iceberg/special/*.inc -j | jq . > connector/iceberg/presto_protocol_iceberg.json ./java-to-struct-json.py --config connector/tpch/presto_protocol_tpch.yml connector/tpch/special/*.inc -j | jq . > connector/tpch/presto_protocol_tpch.json + ./java-to-struct-json.py --config connector/arrow_flight/presto_protocol_arrow_flight.yml connector/arrow_flight/special/*.inc -j | jq . 
> connector/arrow_flight/presto_protocol_arrow_flight.json presto_protocol.proto: presto_protocol-json pystache presto_protocol-protobuf.mustache core/presto_protocol_core.json > core/presto_protocol_core.proto pystache presto_protocol-protobuf.mustache connector/hive/presto_protocol_hive.json > connector/hive/presto_protocol_hive.proto pystache presto_protocol-protobuf.mustache connector/iceberg/presto_protocol_iceberg.json > connector/iceberg/presto_protocol_iceberg.proto pystache presto_protocol-protobuf.mustache connector/tpch/presto_protocol_tpch.json > connector/tpch/presto_protocol_tpch.proto + pystache presto_protocol-protobuf.mustache connector/arrow_flight/presto_protocol_arrow_flight.json > connector/arrow_flight/presto_protocol_arrow_flight.proto diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h new file mode 100644 index 0000000000000..95cda16115695 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" +#include "presto_cpp/presto_protocol/core/ConnectorProtocol.h" + +namespace facebook::presto::protocol::arrow_flight { +using ArrowConnectorProtocol = ConnectorProtocolTemplate< + ArrowTableHandle, + ArrowTableLayoutHandle, + ArrowColumnHandle, + NotImplemented, + NotImplemented, + ArrowSplit, + NotImplemented, + ArrowTransactionHandle, + NotImplemented>; +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache new file mode 100644 index 0000000000000..b6ecb68507285 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache @@ -0,0 +1,150 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// presto_protocol.prolog.cpp +// + +{{#.}} +{{#comment}} +{{comment}} +{{/comment}} +{{/.}} + +#include + +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" +using namespace std::string_literals; + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. 
+ +namespace facebook::presto::protocol::arrow_flight { + +void to_json(json& j, const ArrowTransactionHandle& p) { + j = json::array(); + j.push_back(p._type); + j.push_back(p.instance); +} + +void from_json(const json& j, ArrowTransactionHandle& p) { + j[0].get_to(p._type); + j[1].get_to(p.instance); +} +} // namespace facebook::presto::protocol::arrow_flight +{{#.}} +{{#cinc}} +{{&cinc}} +{{/cinc}} +{{^cinc}} +{{#struct}} +namespace facebook::presto::protocol::arrow_flight { + {{#super_class}} + {{&class_name}}::{{&class_name}}() noexcept { + _type = "{{json_key}}"; + } + {{/super_class}} + + void to_json(json& j, const {{&class_name}}& p) { + j = json::object(); + {{#super_class}} + j["@type"] = "{{&json_key}}"; + {{/super_class}} + {{#fields}} + to_json_key(j, "{{&field_name}}", p.{{field_name}}, "{{&class_name}}", "{{&field_text}}", "{{&field_name}}"); + {{/fields}} + } + + void from_json(const json& j, {{&class_name}}& p) { + {{#super_class}} + p._type = j["@type"]; + {{/super_class}} + {{#fields}} + from_json_key(j, "{{&field_name}}", p.{{field_name}}, "{{&class_name}}", "{{&field_text}}", "{{&field_name}}"); + {{/fields}} + } +} +{{/struct}} +{{#enum}} +namespace facebook::presto::protocol::arrow_flight { + // Loosely copied here from NLOHMANN_JSON_SERIALIZE_ENUM() + + // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays + static const std::pair<{{&class_name}}, json> + {{&class_name}}_enum_table[] = { // NOLINT: cert-err58-cpp + {{#elements}} + { {{&class_name}}::{{&element}}, "{{&element}}" }{{^_last}},{{/_last}} + {{/elements}} + }; + void to_json(json& j, const {{&class_name}}& e) + { + static_assert(std::is_enum<{{&class_name}}>::value, "{{&class_name}} must be an enum!"); + const auto* it = std::find_if(std::begin({{&class_name}}_enum_table), std::end({{&class_name}}_enum_table), + [e](const std::pair<{{&class_name}}, json>& ej_pair) -> bool + { + return ej_pair.first == e; + }); + j = ((it != std::end({{&class_name}}_enum_table)) ? 
it : std::begin({{&class_name}}_enum_table))->second; + } + void from_json(const json& j, {{&class_name}}& e) + { + static_assert(std::is_enum<{{&class_name}}>::value, "{{&class_name}} must be an enum!"); + const auto* it = std::find_if(std::begin({{&class_name}}_enum_table), std::end({{&class_name}}_enum_table), + [&j](const std::pair<{{&class_name}}, json>& ej_pair) -> bool + { + return ej_pair.second == j; + }); + e = ((it != std::end({{&class_name}}_enum_table)) ? it : std::begin({{&class_name}}_enum_table))->first; + } +} +{{/enum}} +{{#abstract}} +namespace facebook::presto::protocol::arrow_flight { + void to_json(json& j, const std::shared_ptr<{{&class_name}}>& p) { + if ( p == nullptr ) { + return; + } + String type = p->_type; + + {{#subclasses}} + if ( type == "{{&key}}" ) { + j = *std::static_pointer_cast<{{&type}}>(p); + return; + } + {{/subclasses}} + + throw TypeError(type + " no abstract type {{&class_name}} {{&key}}"); + } + + void from_json(const json& j, std::shared_ptr<{{&class_name}}>& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error &e) { + throw ParseError(std::string(e.what()) + " {{&class_name}} {{&key}} {{&class_name}}"); + } + + {{#subclasses}} + if ( type == "{{&key}}" ) { + std::shared_ptr<{{&type}}> k = std::make_shared<{{&type}}>(); + j.get_to(*k); + p = std::static_pointer_cast<{{&class_name}}>(k); + return; + } + {{/subclasses}} + + throw TypeError(type + " no abstract type {{&class_name}} {{&key}}"); + } +} +{{/abstract}} +{{/cinc}} +{{/.}} diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache new file mode 100644 index 0000000000000..be08bd9e491c2 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache @@ -0,0 +1,76 @@ +/* + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +{{#.}} +{{#comment}} +{{comment}} +{{/comment}} +{{/.}} + +#include +#include +#include +#include + +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +struct ArrowTransactionHandle : public ConnectorTransactionHandle { + String instance = {}; +}; + +void to_json(json& j, const ArrowTransactionHandle& p); + +void from_json(const json& j, ArrowTransactionHandle& p); + +} // namespace facebook::presto::protocol::arrow_flight +{{#.}} +{{#hinc}} +{{&hinc}} +{{/hinc}} +{{^hinc}} +{{#struct}} +namespace facebook::presto::protocol::arrow_flight { + struct {{class_name}} {{#super_class}}: public {{super_class}}{{/super_class}}{ + {{#fields}} + {{#field_local}}{{#optional}}std::shared_ptr<{{/optional}}{{&field_text}}{{#optional}}>{{/optional}} {{&field_name}} = {};{{/field_local}} + {{/fields}} + + {{#super_class}} + {{class_name}}() noexcept; + {{/super_class}} + }; + void to_json(json& j, const {{class_name}}& p); + void from_json(const json& j, {{class_name}}& p); +} +{{/struct}} +{{#enum}} +namespace facebook::presto::protocol::arrow_flight { + enum class {{class_name}} { + {{#elements}} + {{&element}}{{^_last}},{{/_last}} + {{/elements}} + }; + extern void to_json(json& j, const {{class_name}}& e); + 
extern void from_json(const json& j, {{class_name}}& e); +} +{{/enum}} +{{/hinc}} +{{/.}} diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp new file mode 100644 index 0000000000000..e5b5cf2f9ae3b --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp @@ -0,0 +1,215 @@ +// DO NOT EDIT : This file is generated by chevron +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// presto_protocol.prolog.cpp +// + +// This file is generated DO NOT EDIT @generated + +#include + +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" +using namespace std::string_literals; + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. 
+ +namespace facebook::presto::protocol::arrow_flight { + +void to_json(json& j, const ArrowTransactionHandle& p) { + j = json::array(); + j.push_back(p._type); + j.push_back(p.instance); +} + +void from_json(const json& j, ArrowTransactionHandle& p) { + j[0].get_to(p._type); + j[1].get_to(p.instance); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowColumnHandle::ArrowColumnHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowColumnHandle& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key( + j, + "columnName", + p.columnName, + "ArrowColumnHandle", + "String", + "columnName"); + to_json_key( + j, "columnType", p.columnType, "ArrowColumnHandle", "Type", "columnType"); +} + +void from_json(const json& j, ArrowColumnHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "columnName", + p.columnName, + "ArrowColumnHandle", + "String", + "columnName"); + from_json_key( + j, "columnType", p.columnType, "ArrowColumnHandle", "Type", "columnType"); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowSplit::ArrowSplit() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowSplit& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key( + j, "schemaName", p.schemaName, "ArrowSplit", "String", "schemaName"); + to_json_key(j, "tableName", p.tableName, "ArrowSplit", "String", "tableName"); + to_json_key( + j, + "flightEndpointBytes", + p.flightEndpointBytes, + "ArrowSplit", + "String", + "flightEndpointBytes"); +} + +void from_json(const json& j, ArrowSplit& p) { + p._type = j["@type"]; + from_json_key( + j, "schemaName", p.schemaName, "ArrowSplit", "String", "schemaName"); + from_json_key( + j, "tableName", p.tableName, "ArrowSplit", "String", "tableName"); + from_json_key( + j, + "flightEndpointBytes", + p.flightEndpointBytes, + "ArrowSplit", 
+ "String", + "flightEndpointBytes"); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowTableHandle::ArrowTableHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowTableHandle& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key(j, "schema", p.schema, "ArrowTableHandle", "String", "schema"); + to_json_key(j, "table", p.table, "ArrowTableHandle", "String", "table"); +} + +void from_json(const json& j, ArrowTableHandle& p) { + p._type = j["@type"]; + from_json_key(j, "schema", p.schema, "ArrowTableHandle", "String", "schema"); + from_json_key(j, "table", p.table, "ArrowTableHandle", "String", "table"); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +void to_json(json& j, const std::shared_ptr& p) { + if (p == nullptr) { + return; + } + String type = p->_type; + + if (type == "arrow-flight") { + j = *std::static_pointer_cast(p); + return; + } + + throw TypeError(type + " no abstract type ColumnHandle "); +} + +void from_json(const json& j, std::shared_ptr& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error& e) { + throw ParseError(std::string(e.what()) + " ColumnHandle ColumnHandle"); + } + + if (type == "arrow-flight") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + + throw TypeError(type + " no abstract type ColumnHandle "); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowTableLayoutHandle::ArrowTableLayoutHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowTableLayoutHandle& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key( + j, + "table", + p.table, + "ArrowTableLayoutHandle", + "ArrowTableHandle", + "table"); + to_json_key( + j, + "columnHandles", + 
p.columnHandles, + "ArrowTableLayoutHandle", + "List", + "columnHandles"); + to_json_key( + j, + "tupleDomain", + p.tupleDomain, + "ArrowTableLayoutHandle", + "TupleDomain>", + "tupleDomain"); +} + +void from_json(const json& j, ArrowTableLayoutHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "table", + p.table, + "ArrowTableLayoutHandle", + "ArrowTableHandle", + "table"); + from_json_key( + j, + "columnHandles", + p.columnHandles, + "ArrowTableLayoutHandle", + "List", + "columnHandles"); + from_json_key( + j, + "tupleDomain", + p.tupleDomain, + "ArrowTableLayoutHandle", + "TupleDomain>", + "tupleDomain"); +} +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h new file mode 100644 index 0000000000000..2a9cb81d00b47 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h @@ -0,0 +1,82 @@ +// DO NOT EDIT : This file is generated by chevron +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +// This file is generated DO NOT EDIT @generated + +#include +#include +#include +#include + +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +struct ArrowTransactionHandle : public ConnectorTransactionHandle { + String instance = {}; +}; + +void to_json(json& j, const ArrowTransactionHandle& p); + +void from_json(const json& j, ArrowTransactionHandle& p); + +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowColumnHandle : public ColumnHandle { + String columnName = {}; + Type columnType = {}; + + ArrowColumnHandle() noexcept; +}; +void to_json(json& j, const ArrowColumnHandle& p); +void from_json(const json& j, ArrowColumnHandle& p); +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowSplit : public ConnectorSplit { + String schemaName = {}; + String tableName = {}; + String flightEndpointBytes = {}; + + ArrowSplit() noexcept; +}; +void to_json(json& j, const ArrowSplit& p); +void from_json(const json& j, ArrowSplit& p); +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowTableHandle : public ConnectorTableHandle { + String schema = {}; + String table = {}; + + ArrowTableHandle() noexcept; +}; +void to_json(json& j, const ArrowTableHandle& p); +void from_json(const json& j, ArrowTableHandle& p); +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowTableLayoutHandle : public ConnectorTableLayoutHandle { + ArrowTableHandle table = {}; + List columnHandles = {}; + TupleDomain> tupleDomain = {}; + + ArrowTableLayoutHandle() noexcept; +}; +void 
to_json(json& j, const ArrowTableLayoutHandle& p); +void from_json(const json& j, ArrowTableLayoutHandle& p); +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml new file mode 100644 index 0000000000000..f34f6068eb777 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml @@ -0,0 +1,40 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +AbstractClasses: + ColumnHandle: + super: JsonEncodedSubclass + comparable: true + subclasses: + - { name: ArrowColumnHandle, key: arrow-flight } + + ConnectorTableHandle: + super: JsonEncodedSubclass + subclasses: + - { name: ArrowTableHandle, key: arrow-flight } + + ConnectorTableLayoutHandle: + super: JsonEncodedSubclass + subclasses: + - { name: ArrowTableLayoutHandle, key: arrow-flight } + + ConnectorSplit: + super: JsonEncodedSubclass + subclasses: + - { name: ArrowSplit, key: arrow-flight } + +JavaClasses: + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc new file mode 100644 index 0000000000000..a93325f5b154a --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +void to_json(json& j, const ArrowTransactionHandle& p) { + j = json::array(); + j.push_back(p._type); + j.push_back(p.instance); +} + +void from_json(const json& j, ArrowTransactionHandle& p) { + j[0].get_to(p._type); + j[1].get_to(p.instance); +} +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc new file mode 100644 index 0000000000000..dc573ca2e68cf --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. 
+ +namespace facebook::presto::protocol::arrow_flight { + +struct ArrowTransactionHandle : public ConnectorTransactionHandle { + String instance = {}; +}; + +void to_json(json& j, const ArrowTransactionHandle& p); + +void from_json(const json& j, ArrowTransactionHandle& p); + +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index 63baf0054e053..a26c9b0a70d98 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -1090,6 +1090,7 @@ void from_json(const json& j, std::shared_ptr& p) { */ // dependency TpchTransactionHandle +// dependency ArrowTransactionHandle namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml index 5f1180ae2e191..e280d75036f76 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml @@ -54,6 +54,7 @@ AbstractClasses: - { name: IcebergColumnHandle, key: hive-iceberg } - { name: TpchColumnHandle, key: tpch } - { name: SystemColumnHandle, key: $system@system } + - { name: ArrowColumnHandle, key: arrow-flight } ConnectorPartitioningHandle: super: JsonEncodedSubclass @@ -69,6 +70,7 @@ AbstractClasses: - { name: IcebergTableHandle, key: hive-iceberg } - { name: TpchTableHandle, key: tpch } - { name: SystemTableHandle, key: $system@system } + - { name: ArrowTableHandle, key: arrow-flight } ConnectorOutputTableHandle: super: JsonEncodedSubclass @@ -100,6 +102,7 @@ AbstractClasses: - { name: IcebergTableLayoutHandle, 
key: hive-iceberg } - { name: TpchTableLayoutHandle, key: tpch } - { name: SystemTableLayoutHandle, key: $system@system } + - { name: ArrowTableLayoutHandle, key: arrow-flight } ConnectorMetadataUpdateHandle: super: JsonEncodedSubclass @@ -115,6 +118,7 @@ AbstractClasses: - { name: RemoteSplit, key: $remote } - { name: EmptySplit, key: $empty } - { name: SystemSplit, key: $system@system } + - { name: ArrowSplit, key: arrow-flight } ConnectorHistogram: super: JsonEncodedSubclass diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc index 8ec2a94e84bd9..1dfb17e4a908f 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc @@ -13,6 +13,7 @@ */ // dependency TpchTransactionHandle +// dependency ArrowTransactionHandle namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp index c15084817a434..24f24f27f87a3 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp @@ -15,6 +15,7 @@ // DEPRECATED: This file is deprecated and will be removed in future versions. 
+#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp" #include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp" #include "presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp" #include "presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.cpp" diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h index dd94975e3760d..c43ec92629f44 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h @@ -16,6 +16,7 @@ // DEPRECATED: This file is deprecated and will be removed in future versions. +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" #include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.h" #include "presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h" #include "presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.h" diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml index fc0b86ddfc046..5a5f9d08dfe6c 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml @@ -53,6 +53,7 @@ AbstractClasses: - { name: IcebergColumnHandle, key: hive-iceberg } - { name: TpchColumnHandle, key: tpch } - { name: SystemColumnHandle, key: $system@system } + - { name: ArrowColumnHandle, key: arrow-flight } ConnectorPartitioningHandle: super: JsonEncodedSubclass @@ -68,6 +69,7 @@ AbstractClasses: - { name: IcebergTableHandle, key: hive-iceberg } - { name: TpchTableHandle, key: tpch } - { name: SystemTableHandle, key: $system@system } + - { name: ArrowTableHandle, key: arrow-flight } ConnectorOutputTableHandle: 
super: JsonEncodedSubclass @@ -99,6 +101,7 @@ AbstractClasses: - { name: IcebergTableLayoutHandle, key: hive-iceberg } - { name: TpchTableLayoutHandle, key: tpch } - { name: SystemTableLayoutHandle, key: $system@system } + - { name: ArrowTableLayoutHandle, key: arrow-flight } ConnectorMetadataUpdateHandle: super: JsonEncodedSubclass @@ -114,6 +117,7 @@ AbstractClasses: - { name: RemoteSplit, key: $remote } - { name: EmptySplit, key: $empty } - { name: SystemSplit, key: $system@system } + - { name: ArrowSplit, key: arrow-flight } ConnectorHistogram: super: JsonEncodedSubclass @@ -379,3 +383,7 @@ JavaClasses: - presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/JsonBasedUdfFunctionMetadata.java - presto-spi/src/main/java/com/facebook/presto/spi/plan/DeleteNode.java - presto-spi/src/main/java/com/facebook/presto/spi/plan/BaseInputDistribution.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java diff --git a/presto-native-execution/scripts/setup-adapters.sh b/presto-native-execution/scripts/setup-adapters.sh index 6c36424ebf90c..3cb965fe71781 100755 --- a/presto-native-execution/scripts/setup-adapters.sh +++ b/presto-native-execution/scripts/setup-adapters.sh @@ -35,15 +35,75 @@ function install_prometheus_cpp { cmake_install -DBUILD_SHARED_LIBS=ON -DENABLE_PUSH=OFF -DENABLE_COMPRESSION=OFF } +function install_abseil { + # abseil-cpp + github_checkout abseil/abseil-cpp 20240116.2 --depth 1 + cmake_install \ + -DABSL_BUILD_TESTING=OFF \ + -DCMAKE_CXX_STANDARD=17 \ + -DABSL_PROPAGATE_CXX_STD=ON \ + -DABSL_ENABLE_INSTALL=ON +} + +function install_grpc { + # grpc + github_checkout grpc/grpc v1.48.1 --depth 1 + 
cmake_install \ + -DgRPC_BUILD_TESTS=OFF \ + -DgRPC_ABSL_PROVIDER=package \ + -DgRPC_ZLIB_PROVIDER=package \ + -DgRPC_CARES_PROVIDER=package \ + -DgRPC_RE2_PROVIDER=package \ + -DgRPC_SSL_PROVIDER=package \ + -DgRPC_PROTOBUF_PROVIDER=package \ + -DgRPC_INSTALL=ON +} + +function install_arrow_flight { + ARROW_VERSION="${ARROW_VERSION:-15.0.0}" + if [[ "$OSTYPE" == "linux-gnu"* ]]; then + export INSTALL_PREFIX=${INSTALL_PREFIX:-"/usr/local"} + LINUX_DISTRIBUTION=$(. /etc/os-release && echo ${ID}) + if [[ "$LINUX_DISTRIBUTION" == "ubuntu" || "$LINUX_DISTRIBUTION" == "debian" ]]; then + SUDO="${SUDO:-"sudo --preserve-env"}" + ${SUDO} apt install -y libc-ares-dev + ${SUDO} ldconfig -v 2>/dev/null | grep "${INSTALL_PREFIX}/lib" || \ + echo "${INSTALL_PREFIX}/lib" | ${SUDO} tee /etc/ld.so.conf.d/local-libraries.conf > /dev/null \ + && ${SUDO} ldconfig + else + dnf -y install c-ares-devel + ldconfig -v 2>/dev/null | grep "${INSTALL_PREFIX}/lib" || \ + echo "${INSTALL_PREFIX}/lib" | tee /etc/ld.so.conf.d/local-libraries.conf > /dev/null \ + && ldconfig + fi + else + # The installation script for the Arrow Flight connector currently works only on Linux distributions. 
+ return 0 + fi + + install_abseil + install_grpc + + # NOTE: benchmarks are on due to a compilation error with v15.0.0, once updated that can be removed + # see https://github.com/apache/arrow/issues/41617 + wget_and_untar https://github.com/apache/arrow/archive/apache-arrow-${ARROW_VERSION}.tar.gz arrow + cmake_install_dir arrow/cpp \ + -DARROW_FLIGHT=ON \ + -DARROW_BUILD_BENCHMARKS=ON \ + -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} +} + cd "${DEPENDENCY_DIR}" || exit install_jwt=0 install_prometheus_cpp=0 +install_arrow_flight=0 if [ "$#" -eq 0 ]; then # Install all adapters by default install_jwt=1 install_prometheus_cpp=1 + install_arrow_flight=1 fi while [[ $# -gt 0 ]]; do @@ -56,6 +116,10 @@ while [[ $# -gt 0 ]]; do install_prometheus_cpp=1; shift ;; + arrow_flight) + install_arrow_flight=1; + shift + ;; *) echo "ERROR: Unknown option $1! will be ignored!" shift @@ -72,6 +136,10 @@ if [ $install_prometheus_cpp -eq 1 ]; then install_prometheus_cpp fi +if [ $install_arrow_flight -eq 1 ]; then + install_arrow_flight +fi + _ret=$? if [ $_ret -eq 0 ] ; then echo "All deps for Presto adapters installed!"