diff --git a/presto-native-execution/CMakeLists.txt b/presto-native-execution/CMakeLists.txt index d5001dde70a74..6ac789b73e3a1 100644 --- a/presto-native-execution/CMakeLists.txt +++ b/presto-native-execution/CMakeLists.txt @@ -63,6 +63,8 @@ option(PRESTO_ENABLE_TESTING "Enable tests" ON) option(PRESTO_ENABLE_JWT "Enable JWT (JSON Web Token) authentication" OFF) +option(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR "Enable Arrow Flight connector" OFF) + # Set all Velox options below add_compile_definitions(FOLLY_HAVE_INT128_T=1) diff --git a/presto-native-execution/Makefile b/presto-native-execution/Makefile index f3fb5f709f4d5..7cd37a714867c 100644 --- a/presto-native-execution/Makefile +++ b/presto-native-execution/Makefile @@ -45,6 +45,9 @@ endif ifneq ($(PRESTO_MEMORY_CHECKER_TYPE),) EXTRA_CMAKE_FLAGS += -DPRESTO_MEMORY_CHECKER_TYPE=$(PRESTO_MEMORY_CHECKER_TYPE) endif +ifneq ($(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR),) + EXTRA_CMAKE_FLAGS += -DPRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR=$(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) +endif CMAKE_FLAGS := -DTREAT_WARNINGS_AS_ERRORS=${TREAT_WARNINGS_AS_ERRORS} CMAKE_FLAGS += -DENABLE_ALL_WARNINGS=${ENABLE_WALL} diff --git a/presto-native-execution/README.md b/presto-native-execution/README.md index cccebfcfb8d03..36a8659f6cbac 100644 --- a/presto-native-execution/README.md +++ b/presto-native-execution/README.md @@ -115,6 +115,15 @@ follow these steps: * For development, use `make debug` to build a non-optimized debug version. * Use `make unittest` to build and run tests. +#### Arrow Flight Connector +To enable Arrow Flight connector support, set the environment variable: +`PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR = "ON"`. + +The Arrow Flight connector requires the Arrow Flight library. 
You can install this dependency +by running the following script from the `presto/presto-native-execution` directory: + +`./scripts/setup-adapters.sh arrow_flight` + ### Makefile Targets A reminder of the available Makefile targets can be obtained using `make help` ``` diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index 8cac276cb185b..acafaa3540dcb 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(types) add_subdirectory(http) add_subdirectory(common) add_subdirectory(thrift) +add_subdirectory(connectors) add_library( presto_server_lib @@ -48,6 +49,7 @@ target_link_libraries( presto_common presto_exception presto_function_metadata + presto_connector presto_http presto_operators velox_aggregates diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index f1c77215b1773..100acdeb4fa6d 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -26,6 +26,7 @@ #include "presto_cpp/main/common/ConfigReader.h" #include "presto_cpp/main/common/Counters.h" #include "presto_cpp/main/common/Utils.h" +#include "presto_cpp/main/connectors/ConnectorRegistration.h" #include "presto_cpp/main/http/HttpConstants.h" #include "presto_cpp/main/http/filters/AccessLogFilter.h" #include "presto_cpp/main/http/filters/HttpEndpointLatencyFilter.h" @@ -252,8 +253,9 @@ void PrestoServer::run() { // Register Velox connector factory for iceberg. // The iceberg catalog is handled by the hive connector factory. 
- connector::registerConnectorFactory( - std::make_shared("iceberg")); + velox::connector::registerConnectorFactory( + std::make_shared( + "iceberg")); registerPrestoToVeloxConnector( std::make_unique("hive")); @@ -275,6 +277,8 @@ void PrestoServer::run() { registerPrestoToVeloxConnector( std::make_unique("$system@system")); + presto::connector::registerAllPrestoConnectors(); + initializeVeloxMemory(); initializeThreadPools(); @@ -1108,18 +1112,18 @@ PrestoServer::getAdditionalHttpServerFilters() { void PrestoServer::registerConnectorFactories() { // These checks for connector factories can be removed after we remove the // registrations from the Velox library. - if (!connector::hasConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName)) { - connector::registerConnectorFactory( - std::make_shared()); - connector::registerConnectorFactory( - std::make_shared( + if (!velox::connector::hasConnectorFactory( + velox::connector::hive::HiveConnectorFactory::kHiveConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared()); + velox::connector::registerConnectorFactory( + std::make_shared( kHiveHadoop2ConnectorName)); } - if (!connector::hasConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName)) { - connector::registerConnectorFactory( - std::make_shared()); + if (!velox::connector::hasConnectorFactory( + velox::connector::tpch::TpchConnectorFactory::kTpchConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared()); } } diff --git a/presto-native-execution/presto_cpp/main/SystemConnector.cpp b/presto-native-execution/presto_cpp/main/SystemConnector.cpp index 9de215e014cac..12a49ad4d8285 100644 --- a/presto-native-execution/presto_cpp/main/SystemConnector.cpp +++ b/presto-native-execution/presto_cpp/main/SystemConnector.cpp @@ -350,7 +350,8 @@ std::optional SystemDataSource::next( std::unique_ptr SystemPrestoToVeloxConnector::toVeloxSplit( const protocol::ConnectorId& catalogId, 
- const protocol::ConnectorSplit* const connectorSplit) const { + const protocol::ConnectorSplit* const connectorSplit, + const std::map& extraCredentials) const { auto systemSplit = dynamic_cast(connectorSplit); VELOX_CHECK_NOT_NULL( systemSplit, "Unexpected split type {}", connectorSplit->_type); diff --git a/presto-native-execution/presto_cpp/main/SystemConnector.h b/presto-native-execution/presto_cpp/main/SystemConnector.h index b467cf25676b6..21ecab380810d 100644 --- a/presto-native-execution/presto_cpp/main/SystemConnector.h +++ b/presto-native-execution/presto_cpp/main/SystemConnector.h @@ -184,7 +184,9 @@ class SystemPrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr toVeloxSplit( const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* connectorSplit) const final; + const protocol::ConnectorSplit* connectorSplit, + const std::map& extraCredentials = {}) + const final; std::unique_ptr toVeloxColumnHandle( const protocol::ColumnHandle* column, diff --git a/presto-native-execution/presto_cpp/main/TaskManager.cpp b/presto-native-execution/presto_cpp/main/TaskManager.cpp index afab868c95a26..93e6d81399bbf 100644 --- a/presto-native-execution/presto_cpp/main/TaskManager.cpp +++ b/presto-native-execution/presto_cpp/main/TaskManager.cpp @@ -437,6 +437,7 @@ std::unique_ptr TaskManager::createOrUpdateTask( planFragment, updateRequest.sources, updateRequest.outputIds, + updateRequest.extraCredentials, std::move(queryCtx), startProcessCpuTime); } @@ -456,6 +457,7 @@ std::unique_ptr TaskManager::createOrUpdateBatchTask( planFragment, updateRequest.sources, updateRequest.outputIds, + updateRequest.extraCredentials, std::move(queryCtx), startProcessCpuTime); } @@ -465,6 +467,7 @@ std::unique_ptr TaskManager::createOrUpdateTaskImpl( const velox::core::PlanFragment& planFragment, const std::vector& sources, const protocol::OutputBuffers& outputBuffers, + const std::map& extraCredentials, std::shared_ptr queryCtx, long 
startProcessCpuTime) { std::shared_ptr execTask; @@ -565,7 +568,7 @@ std::unique_ptr TaskManager::createOrUpdateTaskImpl( // Keep track of the max sequence for this batch of splits. long maxSplitSequenceId{-1}; for (const auto& protocolSplit : source.splits) { - auto split = toVeloxSplit(protocolSplit); + auto split = toVeloxSplit(protocolSplit, extraCredentials); if (split.hasConnectorSplit()) { maxSplitSequenceId = std::max(maxSplitSequenceId, protocolSplit.sequenceId); diff --git a/presto-native-execution/presto_cpp/main/TaskManager.h b/presto-native-execution/presto_cpp/main/TaskManager.h index 5453a188309fe..1fddbbfd7a161 100644 --- a/presto-native-execution/presto_cpp/main/TaskManager.h +++ b/presto-native-execution/presto_cpp/main/TaskManager.h @@ -179,6 +179,7 @@ class TaskManager { const velox::core::PlanFragment& planFragment, const std::vector& sources, const protocol::OutputBuffers& outputBuffers, + const std::map& extraCredentials, std::shared_ptr queryCtx, long startProcessCpuTime); diff --git a/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt new file mode 100644 index 0000000000000..5e4e2b04e5e02 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt @@ -0,0 +1,20 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+if(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) + add_subdirectory(arrow_flight) +endif() + +add_library(presto_connector ConnectorRegistration.cpp) + +if(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) + target_link_libraries(presto_connector presto_flight_connector) +endif() diff --git a/presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.cpp b/presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.cpp new file mode 100644 index 0000000000000..ac2cc164c12ae --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.cpp @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/connectors/ConnectorRegistration.h" + +#ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h" +#endif + +namespace facebook::presto::connector { + +void registerAllPrestoConnectors() { +#ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR + registerPrestoToVeloxConnector( + std::make_unique< + presto::connector::arrow_flight::ArrowPrestoToVeloxConnector>( + "arrow-flight")); + + if (!velox::connector::hasConnectorFactory( + presto::connector::arrow_flight::ArrowFlightConnectorFactory:: + kArrowFlightConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared< + presto::connector::arrow_flight::ArrowFlightConnectorFactory>()); + } +#endif +} + +} // namespace facebook::presto::connector diff --git a/presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.h b/presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.h new file mode 100644 index 0000000000000..187362876247f --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.h @@ -0,0 +1,20 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +namespace facebook::presto::connector { + +void registerAllPrestoConnectors(); + +} // namespace facebook::presto::connector diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp new file mode 100644 index 0000000000000..cb8bc885d583a --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp @@ -0,0 +1,187 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "arrow/c/abi.h" +#include "arrow/c/bridge.h" +#include "presto_cpp/main/common/ConfigReader.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "velox/vector/arrow/Bridge.h" + +namespace facebook::presto::connector::arrow_flight { + +using namespace arrow::flight; +using namespace velox; +using namespace velox::connector; + +// Wrapper for CallOptions which does not add any member variables, +// but provides a write-only interface for adding call headers. 
+class CallOptionsAddHeaders : public FlightCallOptions, public AddCallHeaders { + public: + void AddHeader(const std::string& key, const std::string& value) override { + headers.emplace_back(key, value); + } +}; + +std::optional ArrowFlightConnector::getDefaultLocation( + const std::shared_ptr& config) { + auto defaultHost = config->defaultServerHostname(); + auto defaultPort = config->defaultServerPort(); + if (!defaultHost.has_value() || !defaultPort.has_value()) { + return std::nullopt; + } + + bool defaultSslEnabled = config->defaultServerSslEnabled(); + AFC_RETURN_OR_RAISE( + defaultSslEnabled + ? Location::ForGrpcTls(defaultHost.value(), defaultPort.value()) + : Location::ForGrpcTcp(defaultHost.value(), defaultPort.value())); +} + +std::shared_ptr +ArrowFlightConnector::initClientOpts( + const std::shared_ptr& config) { + auto clientOpts = std::make_shared(); + clientOpts->disable_server_verification = !config->serverVerify(); + + auto certPath = config->serverSslCertificate(); + if (certPath.hasValue()) { + std::ifstream file(certPath.value()); + VELOX_CHECK(file.is_open(), "Could not open TLS certificate"); + std::string cert( + (std::istreambuf_iterator(file)), + (std::istreambuf_iterator())); + clientOpts->tls_root_certs = cert; + } + + return clientOpts; +} + +FlightDataSource::FlightDataSource( + const RowTypePtr& outputType, + const std::unordered_map>& + columnHandles, + std::shared_ptr authenticator, + memory::MemoryPool* pool, + const std::shared_ptr& flightConfig, + const std::shared_ptr& clientOpts, + const std::optional defaultLocation) + : outputType_{outputType}, + authenticator_{authenticator}, + pool_{pool}, + flightConfig_{flightConfig}, + clientOpts_{clientOpts}, + defaultLocation_{defaultLocation} { + // columnMapping_ contains the real column names in the expected order. + // This is later used by projectOutputColumns to filter out unnecessary + // columns from the fetched chunk. 
+ columnMapping_.reserve(outputType_->size()); + + for (auto columnName : outputType_->names()) { + auto it = columnHandles.find(columnName); + VELOX_CHECK( + it != columnHandles.end(), + "missing columnHandle for column '{}'", + columnName); + + auto handle = std::dynamic_pointer_cast(it->second); + VELOX_CHECK_NOT_NULL( + handle, + "handle for column '{}' is not an FlightColumnHandle", + columnName); + + columnMapping_.push_back(handle->name()); + } +} + +void FlightDataSource::addSplit(std::shared_ptr split) { + auto flightSplit = std::dynamic_pointer_cast(split); + VELOX_CHECK(flightSplit, "FlightDataSource received wrong type of split"); + + auto& locs = flightSplit->locations; + Location loc; + if (locs.size() > 0) { + AFC_ASSIGN_OR_RAISE(loc, Location::Parse(locs[0])); + } else { + VELOX_CHECK( + defaultLocation_.has_value(), + "Split has empty Location list, but default host or port is missing"); + loc = defaultLocation_.value(); + } + + AFC_ASSIGN_OR_RAISE( + auto client, + FlightClient::Connect( + loc, clientOpts_ ? *clientOpts_ : FlightClientOptions{})); + + CallOptionsAddHeaders callOptsAddHeaders{}; + FlightCallOptions& callOpts = callOptsAddHeaders; + AddCallHeaders& headerWriter = callOptsAddHeaders; + authenticator_->authenticateClient( + client, flightSplit->extraCredentials, headerWriter); + + auto ticket = Ticket{flightSplit->ticket}; + auto readerResult = client->DoGet(callOpts, ticket); + VELOX_CHECK( + readerResult.ok(), + "Server replied with error: {}", + readerResult.status().message()); + currentReader_ = std::move(readerResult).ValueUnsafe(); +} + +std::optional FlightDataSource::next( + uint64_t size, + velox::ContinueFuture& /* unused */) { + VELOX_CHECK_NOT_NULL(currentReader_, "Missing split, call addSplit() first"); + + AFC_ASSIGN_OR_RAISE(auto chunk, currentReader_->Next()); + auto recordBatch = std::move(chunk).data; + + // Null data in the chunk indicates that the Flight stream is complete. 
+ if (!recordBatch) { + currentReader_ = nullptr; + return nullptr; + } + + // Extract only required columns from the record batch as a velox RowVector. + auto output = projectOutputColumns(recordBatch); + + completedRows_ += output->size(); + completedBytes_ += output->inMemoryBytes(); + return output; +} + +RowVectorPtr FlightDataSource::projectOutputColumns( + const std::shared_ptr& input) { + std::vector children; + children.reserve(columnMapping_.size()); + + // Extract and convert desired columns in the correct order. + for (auto name : columnMapping_) { + auto column = input->GetColumnByName(name); + VELOX_CHECK_NOT_NULL(column, "column with name '{}' not found", name); + ArrowArray array; + ArrowSchema schema; + AFC_RAISE_NOT_OK(arrow::ExportArray(*column, &array, &schema)); + children.push_back(importFromArrowAsOwner(schema, array, pool_)); + } + + return std::make_shared( + pool_, + outputType_, + BufferPtr() /*nulls*/, + input->num_rows(), + std::move(children)); +} + +} // namespace facebook::presto::connector::arrow_flight diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h new file mode 100644 index 0000000000000..13462623ca34c --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h @@ -0,0 +1,200 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "arrow/flight/api.h" +#include "presto_cpp/main/connectors/arrow_flight/FlightConfig.h" +#include "presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h" +#include "velox/connectors/Connector.h" + +namespace facebook::presto::connector::arrow_flight { + +class FlightTableHandle : public velox::connector::ConnectorTableHandle { + public: + explicit FlightTableHandle(const std::string& connectorId) + : ConnectorTableHandle(connectorId) {} +}; + +struct FlightSplit : public velox::connector::ConnectorSplit { + /// @param connectorId Connector ID forwarded to the ConnectorSplit base + /// @param ticket Flight Ticket obtained from `GetFlightInfo` + /// @param locations Locations which can consume the ticket + /// @param extraCredentials Extra credentials for authentication + FlightSplit( + const std::string& connectorId, + const std::string& ticket, + const std::vector& locations = {}, + const std::map& extraCredentials = {}) + : ConnectorSplit(connectorId), + ticket(ticket), + locations(locations), + extraCredentials(extraCredentials) {} + + const std::string ticket; + const std::vector locations; + std::map extraCredentials; +}; + +class FlightColumnHandle : public velox::connector::ColumnHandle { + public: + FlightColumnHandle(const std::string& columnName) : columnName_(columnName) {} + + const std::string& name() { + return columnName_; + } + + private: + std::string columnName_; +}; + +class FlightDataSource : public velox::connector::DataSource { + public: + FlightDataSource( + const velox::RowTypePtr& outputType, + const std::unordered_map< + std::string, + std::shared_ptr>& columnHandles, + std::shared_ptr authenticator, + velox::memory::MemoryPool* pool, + const std::shared_ptr& flightConfig, + const std::shared_ptr& clientOpts, + const std::optional defaultLocation = + std::nullopt); + + void addSplit( + std::shared_ptr split) override; + + std::optional next( + uint64_t size, + velox::ContinueFuture& /* unused */) override; + + void addDynamicFilter( + 
velox::column_index_t outputChannel, + const std::shared_ptr& filter) override { + VELOX_NYI("This connector doesn't support dynamic filters"); + } + + uint64_t getCompletedBytes() override { + return completedBytes_; + } + + uint64_t getCompletedRows() override { + return completedRows_; + } + + std::unordered_map runtimeStats() + override { + return {}; + } + + private: + /// Convert an arrow record batch to Velox RowVector. + /// Process only those columns that are present in outputType_. + velox::RowVectorPtr projectOutputColumns( + const std::shared_ptr& input); + + velox::RowTypePtr outputType_; + std::vector columnMapping_; + std::unique_ptr currentReader_; + uint64_t completedRows_ = 0; + uint64_t completedBytes_ = 0; + std::shared_ptr authenticator_; + velox::memory::MemoryPool* const pool_; + const std::shared_ptr flightConfig_; + const std::shared_ptr clientOpts_; + const std::optional defaultLocation_; +}; + +class ArrowFlightConnector : public velox::connector::Connector { + public: + explicit ArrowFlightConnector( + const std::string& id, + std::shared_ptr config, + const char* authenticatorName = nullptr) + : Connector(id), + flightConfig_(std::make_shared(config)), + clientOpts_(initClientOpts(flightConfig_)), + defaultLocation_(getDefaultLocation(flightConfig_)), + authenticator_(auth::getAuthenticatorFactory( + authenticatorName + ? 
authenticatorName + : flightConfig_->authenticatorName()) + ->newAuthenticator(config)) {} + + std::unique_ptr createDataSource( + const velox::RowTypePtr& outputType, + const std::shared_ptr& + tableHandle, + const std::unordered_map< + std::string, + std::shared_ptr>& columnHandles, + velox::connector::ConnectorQueryCtx* ctx) override { + return std::make_unique( + outputType, + columnHandles, + authenticator_, + ctx->memoryPool(), + flightConfig_, + clientOpts_, + defaultLocation_); + } + + std::unique_ptr createDataSink( + velox::RowTypePtr inputType, + std::shared_ptr + connectorInsertTableHandle, + velox::connector::ConnectorQueryCtx* connectorQueryCtx, + velox::connector::CommitStrategy commitStrategy) override { + VELOX_NYI("Flight connector does not support DataSink"); + } + + private: + // Returns the default location specified in the FlightConfig. + // Returns nullopt if either host or port is missing. + static std::optional getDefaultLocation( + const std::shared_ptr& config); + + static std::shared_ptr initClientOpts( + const std::shared_ptr& config); + + const std::shared_ptr flightConfig_; + const std::shared_ptr clientOpts_; + const std::optional defaultLocation_; + const std::shared_ptr authenticator_; +}; + +class ArrowFlightConnectorFactory : public velox::connector::ConnectorFactory { + public: + static constexpr const char* kArrowFlightConnectorName = "arrow-flight"; + + ArrowFlightConnectorFactory() : ConnectorFactory(kArrowFlightConnectorName) {} + + explicit ArrowFlightConnectorFactory( + const char* name, + const char* authenticatorName = nullptr) + : ConnectorFactory(name), authenticatorName_(authenticatorName) {} + + std::shared_ptr newConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* executor = nullptr) override { + return std::make_shared( + id, config, authenticatorName_); + } + + private: + const char* authenticatorName_{nullptr}; +}; + +} // namespace facebook::presto::connector::arrow_flight diff --git 
a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp new file mode 100644 index 0000000000000..6a4883ebb2108 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h" +#include "folly/base64.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h" + +namespace facebook::presto::connector::arrow_flight { + +std::unique_ptr +ArrowPrestoToVeloxConnector::toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* const connectorSplit, + const std::map& extraCredentials) const { + auto arrowSplit = + dynamic_cast(connectorSplit); + VELOX_CHECK_NOT_NULL( + arrowSplit, "Unexpected split type {}", connectorSplit->_type); + return std::make_unique( + catalogId, + folly::base64Decode(arrowSplit->ticket), + arrowSplit->locationUrls, + extraCredentials); +} + +std::unique_ptr +ArrowPrestoToVeloxConnector::toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const { + auto arrowColumn = + dynamic_cast(column); + VELOX_CHECK_NOT_NULL( + 
arrowColumn, "Unexpected column handle type {}", column->_type); + return std::make_unique( + arrowColumn->columnName); +} + +std::unique_ptr +ArrowPrestoToVeloxConnector::toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser, + std::unordered_map< + std::string, + std::shared_ptr>& assignments) const { + return std::make_unique( + tableHandle.connectorId); +} + +std::unique_ptr +ArrowPrestoToVeloxConnector::createConnectorProtocol() const { + return std::make_unique(); +} + +} // namespace facebook::presto::connector::arrow_flight diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h new file mode 100644 index 0000000000000..8c6b9937b7b9f --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "presto_cpp/main/types/PrestoToVeloxConnector.h" + +namespace facebook::presto::connector::arrow_flight { + +class ArrowPrestoToVeloxConnector final : public PrestoToVeloxConnector { + public: + explicit ArrowPrestoToVeloxConnector(std::string connectorName) + : PrestoToVeloxConnector(std::move(connectorName)) {} + + std::unique_ptr toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const std::map& extraCredentials = {}) + const final; + + std::unique_ptr toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const final; + + std::unique_ptr toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser, + std::unordered_map< + std::string, + std::shared_ptr>& assignments) + const final; + + std::unique_ptr createConnectorProtocol() + const final; +}; + +} // namespace facebook::presto::connector::arrow_flight diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt new file mode 100644 index 0000000000000..870481c34acd0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt @@ -0,0 +1,43 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+find_package(Arrow REQUIRED) +find_package(PkgConfig REQUIRED) +pkg_check_modules(ARROW_FLIGHT REQUIRED arrow-flight) + +if(NOT ARROW_FLIGHT_FOUND) + message(FATAL_ERROR "Arrow Flight package not found") +endif() + +set(ArrowFlight_FOUND TRUE) +set(ArrowFlight_INCLUDE_DIRS ${ARROW_FLIGHT_INCLUDE_DIRS}) +set(ArrowFlight_LIBRARIES ${ARROW_FLIGHT_LIBRARIES}) +include_directories(${ArrowFlight_INCLUDE_DIRS}) + +add_subdirectory(auth) + +add_library(presto_flight_connector_utils INTERFACE Macros.h) +target_link_libraries(presto_flight_connector_utils INTERFACE velox_exception) + +add_library( + presto_flight_connector OBJECT + ArrowFlightConnector.cpp ArrowPrestoToVeloxConnector.cpp FlightConfig.cpp) + +target_compile_definitions(presto_flight_connector + PUBLIC PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) + +target_link_libraries( + presto_flight_connector velox_connector ${ArrowFlight_LIBRARIES} + presto_flight_connector_utils presto_flight_connector_auth presto_types) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.cpp new file mode 100644 index 0000000000000..4982c88001bb3 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.cpp @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "presto_cpp/main/connectors/arrow_flight/FlightConfig.h"
+
+namespace facebook::presto::connector::arrow_flight {
+
+// NOTE(review): none of the accessors below modifies the object; they could
+// be const-qualified together with their declarations in FlightConfig.h.
+
+/// Returns the value stored under kAuthenticatorName, or "none" when the key
+/// is not set.
+std::string FlightConfig::authenticatorName() {
+  return config_->get(kAuthenticatorName, "none");
+}
+
+/// Returns the value stored under kDefaultServerHost, converted to a
+/// std::optional; empty when the key is not set.
+std::optional FlightConfig::defaultServerHostname() {
+  return static_cast>(
+      config_->get(kDefaultServerHost));
+}
+
+/// Returns the value stored under kDefaultServerPort, converted to a
+/// std::optional; empty when the key is not set.
+std::optional FlightConfig::defaultServerPort() {
+  return static_cast>(
+      config_->get(kDefaultServerPort));
+}
+
+/// Returns the value stored under kDefaultServerSslEnabled, or false when
+/// the key is not set.
+bool FlightConfig::defaultServerSslEnabled() {
+  return config_->get(kDefaultServerSslEnabled, false);
+}
+
+/// Returns the value stored under kServerVerify, or true when the key is not
+/// set.
+bool FlightConfig::serverVerify() {
+  return config_->get(kServerVerify, true);
+}
+
+/// Returns the value stored under kServerSslCertificate as a folly::Optional;
+/// empty when the key is not set.
+folly::Optional FlightConfig::serverSslCertificate() {
+  return config_->get(kServerSslCertificate);
+}
+
+} // namespace facebook::presto::connector::arrow_flight
diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.h
new file mode 100644
index 0000000000000..33722a85a76bf
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.h
@@ -0,0 +1,56 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+// NOTE(review): this header names std::shared_ptr, std::optional, std::string
+// and folly::Optional but includes only Config.h; presumably they arrive
+// transitively -- confirm, or include them directly.
+#include "velox/common/config/Config.h"
+
+namespace facebook::presto::connector::arrow_flight {
+
+/// Typed accessors over the Arrow Flight connector's configuration. The
+/// object keeps a shared pointer to the underlying config and looks values up
+/// by the k* keys declared below.
+class FlightConfig {
+ public:
+  /// @param config configuration object the accessors read from.
+  /// NOTE(review): `config` is taken by value and then copied into config_;
+  /// std::move(config) would avoid the extra shared_ptr copy.
+  explicit FlightConfig(std::shared_ptr config)
+      : config_{config} {}
+
+  /// Key read by authenticatorName().
+  static constexpr const char* kAuthenticatorName =
+      "arrow-flight.authenticator.name";
+
+  /// Key read by defaultServerHostname().
+  static constexpr const char* kDefaultServerHost = "arrow-flight.server";
+
+  /// Key read by defaultServerPort().
+  static constexpr const char* kDefaultServerPort = "arrow-flight.server.port";
+
+  /// Key read by defaultServerSslEnabled().
+  static constexpr const char* kDefaultServerSslEnabled =
+      "arrow-flight.server-ssl-enabled";
+
+  /// Key read by serverVerify().
+  static constexpr const char* kServerVerify = "arrow-flight.server.verify";
+
+  /// Key read by serverSslCertificate().
+  static constexpr const char* kServerSslCertificate =
+      "arrow-flight.server-ssl-certificate";
+
+  // The accessors are defined in FlightConfig.cpp.
+  // NOTE(review): none of them is const-qualified, so they cannot be called
+  // through a const FlightConfig.
+
+  /// Value of kAuthenticatorName.
+  std::string authenticatorName();
+
+  /// Value of kDefaultServerHost as an optional.
+  std::optional defaultServerHostname();
+
+  /// Value of kDefaultServerPort as an optional.
+  std::optional defaultServerPort();
+
+  /// Value of kDefaultServerSslEnabled.
+  bool defaultServerSslEnabled();
+
+  /// Value of kServerVerify.
+  bool serverVerify();
+
+  /// Value of kServerSslCertificate as a folly::Optional.
+  folly::Optional serverSslCertificate();
+
+ private:
+  // Underlying configuration; shared with whoever constructed this object.
+  std::shared_ptr config_;
+};
+
+} // namespace facebook::presto::connector::arrow_flight
diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h
new file mode 100644
index 0000000000000..5ab725e582cc6
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h
@@ -0,0 +1,50 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <utility>
+
+// The macros below expand to ::arrow::Status, GenericToStatus and
+// ARROW_ASSIGN_OR_RAISE_NAME, so include the Arrow headers that declare them
+// instead of relying on the including file to have done so.
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "velox/common/base/Exceptions.h"
+
+// Macros for dealing with arrow::Status and arrow::Result objects
+// and converting them to velox exceptions.
+//
+// Temporaries introduced by the macros use an `afc`-prefixed name instead of
+// a double-underscore name, because identifiers containing `__` are reserved
+// for the implementation.
+
+/// Raise a Velox exception if status is not OK.
+/// Counterpart of ARROW_RETURN_NOT_OK.
+/// `status` is evaluated exactly once.
+#define AFC_RAISE_NOT_OK(status) \
+  do { \
+    ::arrow::Status afcStatus_ = ::arrow::internal::GenericToStatus(status); \
+    VELOX_CHECK(afcStatus_.ok(), afcStatus_.message()); \
+  } while (false)
+
+/// Implementation detail of AFC_ASSIGN_OR_RAISE; `result_name` is the name of
+/// the temporary that holds the evaluated `rexpr`.
+/// Expands to several statements rather than a do/while block, so it must not
+/// be used as the sole statement of an unbraced if/else.
+#define AFC_ASSIGN_OR_RAISE_IMPL(result_name, lhs, rexpr) \
+  auto&& result_name = (rexpr); \
+  VELOX_CHECK((result_name).ok(), (result_name).status().message()); \
+  lhs = std::move(result_name).ValueUnsafe();
+
+/// Raise a Velox exception if expr doesn't return an OK result,
+/// else unwrap the value and assign it to `lhs`.
+/// `std::move`s its right hand operand.
+/// Counterpart of ARROW_ASSIGN_OR_RAISE.
+/// __COUNTER__ gives every expansion its own temporary name, so the macro can
+/// be used more than once in the same scope.
+#define AFC_ASSIGN_OR_RAISE(lhs, rexpr) \
+  AFC_ASSIGN_OR_RAISE_IMPL( \
+      ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), lhs, rexpr);
+
+/// Raise a Velox exception if rexpr doesn't return an OK result,
+/// else unwrap the value and return it.
+/// `std::move`s its right hand operand.
+/// Contains a `return`, so it is only usable inside a function whose return
+/// type can be initialized from the unwrapped value.
+#define AFC_RETURN_OR_RAISE(rexpr) \
+  do { \
+    auto&& afcResult_ = (rexpr); \
+    VELOX_CHECK(afcResult_.ok(), afcResult_.status().message()); \
+    return std::move(afcResult_).ValueUnsafe(); \
+  } while (false)
diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp
new file mode 100644
index 0000000000000..0f6eb412d09d6
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp
@@ -0,0 +1,47 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h"
+#include "velox/common/base/Exceptions.h"
+
+namespace facebook::presto::connector::arrow_flight::auth {
+
+/// Returns the registry of authenticator factories, keyed by factory name.
+/// The map is a function-local static, so it is constructed on first call;
+/// this matters because AFC_REGISTER_AUTH_FACTORY registers factories from
+/// static initializers.
+/// NOTE(review): no locking is visible here -- confirm that registration and
+/// lookup never run concurrently.
+static auto& authenticatorFactories() {
+  static std::unordered_map>
+      factories;
+  return factories;
+}
+
+/// Adds `factory` to the registry under factory->name().
+/// Fails the VELOX_CHECK when a factory with the same name is already
+/// registered.
+/// @return always true; the bool exists so the call can be used to
+/// initialize a static variable (see AFC_REGISTER_AUTH_FACTORY).
+bool registerAuthenticatorFactory(
+    std::shared_ptr factory) {
+  bool ok = authenticatorFactories().insert({factory->name(), factory}).second;
+  VELOX_CHECK(
+      ok,
+      "Flight AuthenticatorFactory with name {} is already registered",
+      factory->name());
+  return true;
+}
+
+/// Looks up the factory registered under `name`.
+/// Fails the VELOX_CHECK when no factory with that name is registered.
+std::shared_ptr getAuthenticatorFactory(
+    const std::string& name) {
+  auto it = authenticatorFactories().find(name);
+  VELOX_CHECK(
+      it != authenticatorFactories().end(),
+      "Flight AuthenticatorFactory with name {} not registered",
+      name);
+  return it->second;
+}
+
+// Registers a factory at static-initialization time of this translation unit.
+AFC_REGISTER_AUTH_FACTORY(std::make_shared())
+
+} // namespace facebook::presto::connector::arrow_flight::auth
diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h
new file mode 100644
index 0000000000000..066132bfd9a5c
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h
@@ -0,0 +1,85 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "arrow/flight/api.h"
+#include "velox/common/config/Config.h"
+
+namespace facebook::presto::connector::arrow_flight::auth {
+
+/// Interface for authenticating a Flight client.
+/// NOTE(review): polymorphic base with no virtual destructor declared.
+class Authenticator {
+ public:
+  /// @brief Override this method to define implementation-specific
+  /// authentication. This could be through client->Authenticate, or
+  /// client->AuthenticateBasicToken, or any other custom strategy.
+  /// @param client the Flight client which is to be authenticated
+  /// @param extraCredentials extra credential data used for authentication
+  /// @param headerWriter write-only object used to set authentication headers
+  virtual void authenticateClient(
+      std::unique_ptr& client,
+      const std::map& extraCredentials,
+      arrow::flight::AddCallHeaders& headerWriter) = 0;
+};
+
+/// Named factory for Authenticator objects; the name is the key under which
+/// the factory is registered and looked up.
+/// NOTE(review): polymorphic base with no virtual destructor declared, and
+/// the single-argument constructor is not explicit.
+class AuthenticatorFactory {
+ public:
+  /// @param name copied into the factory and returned by name().
+  AuthenticatorFactory(std::string_view name) : name_{name} {}
+
+  /// Name given at construction.
+  const std::string& name() const {
+    return name_;
+  }
+
+  /// Creates an authenticator.
+  /// @param config configuration object handed to the new authenticator.
+  virtual std::shared_ptr newAuthenticator(
+      const std::shared_ptr config) = 0;
+
+ private:
+  std::string name_;
+};
+
+/// Registers `factory` under its name; defined in Authenticator.cpp.
+bool registerAuthenticatorFactory(
+    std::shared_ptr factory);
+
+/// Returns the factory registered under `name`; defined in Authenticator.cpp.
+std::shared_ptr getAuthenticatorFactory(
+    const std::string& name);
+
+// Defines a bool in an anonymous namespace whose initializer calls
+// registerAuthenticatorFactory(factory), so the registration runs during
+// static initialization of the translation unit that uses the macro.
+// NOTE(review): FB_ANONYMOUS_VARIABLE is not declared by anything included
+// directly in this header; presumably it comes in transitively -- confirm.
+#define AFC_REGISTER_AUTH_FACTORY(factory) \
+  namespace { \
+  static bool FB_ANONYMOUS_VARIABLE(g_ConnectorFactory) = ::facebook::presto:: \
+      connector::arrow_flight::auth::registerAuthenticatorFactory((factory)); \
+  }
+
+/// Authenticator whose authenticateClient() has an empty body: the client,
+/// the credentials and the header writer are all left untouched.
+class NoOpAuthenticator : public Authenticator {
+ public:
+  void authenticateClient(
+      std::unique_ptr& client,
+      const std::map& extraCredentials,
+      arrow::flight::AddCallHeaders& headerWriter) override {}
+};
+
+/// Factory for the no-op authenticator, named "none" by default.
+class NoOpAuthenticatorFactory : public AuthenticatorFactory {
+ public:
+  /// Name used by the default constructor.
+  static constexpr const std::string_view kNoOpAuthenticatorName{"none"};
+
+  NoOpAuthenticatorFactory() : AuthenticatorFactory{kNoOpAuthenticatorName} {}
+
+  /// Same factory registered under a caller-chosen name.
+  NoOpAuthenticatorFactory(std::string_view name)
+      : AuthenticatorFactory{name} {}
+
+  /// `config` is not used by this implementation.
+  std::shared_ptr newAuthenticator(
+      const std::shared_ptr config) override {
+    return std::make_shared();
+  }
+};
+
+} // namespace facebook::presto::connector::arrow_flight::auth
diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt
new file mode 100644
index 0000000000000..1e7eba3154a0e
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Authenticator interface, factory registry and the built-in no-op
+# authenticator.
+add_library(presto_flight_connector_auth Authenticator.cpp)
+
+# Authenticator.cpp uses VELOX_CHECK, hence velox_exception.
+# NOTE(review): no PUBLIC/PRIVATE keyword is given on this link line.
+target_link_libraries(presto_flight_connector_auth
+                      presto_flight_connector_utils velox_exception)
diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp
new file mode 100644
index 0000000000000..dd12509793023
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp
@@ -0,0 +1,355 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "arrow/testing/gtest_util.h" +#include "folly/init/Init.h" +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" + +namespace facebook::presto::connector::arrow_flight::test { + +namespace { + +namespace velox = facebook::velox; +namespace core = facebook::velox::core; +namespace exec = facebook::velox::exec; +namespace connector = facebook::velox::connector; +namespace flight = arrow::flight; + +using namespace facebook::presto::connector::arrow_flight; +using namespace facebook::presto::connector::arrow_flight::test; +using exec::test::AssertQueryBuilder; +using exec::test::OperatorTestBase; +using exec::test::PlanBuilder; + +static const std::string kFlightConnectorId = "test-flight"; + +class FlightConnectorDataTypeTest : public FlightWithServerTestBase {}; + +TEST_F(FlightConnectorDataTypeTest, booleanType) { + updateTable( + "sample-data", + makeArrowTable( + {"bool_col"}, {makeBooleanArray({true, false, true, false})})); + + auto boolVec = makeFlatVector({true, false, true, false}); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW({"bool_col"}, {velox::BOOLEAN()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({boolVec})); +} + +TEST_F(FlightConnectorDataTypeTest, integerTypes) { + 
updateTable( + "sample-data", + makeArrowTable( + {"tinyint_col", "smallint_col", "integer_col", "bigint_col"}, + {makeNumericArray( + {-128, 0, 127, std::numeric_limits::max()}), + makeNumericArray( + {-32768, 0, 32767, std::numeric_limits::max()}), + makeNumericArray( + {-2147483648, + 0, + 2147483647, + std::numeric_limits::max()}), + makeNumericArray( + {-3435678987654321234LL, + 0, + 4527897896541234567LL, + std::numeric_limits::max()})})); + + auto tinyintVec = makeFlatVector( + {-128, 0, 127, std::numeric_limits::max()}); + + auto smallintVec = makeFlatVector( + {-32768, 0, 32767, std::numeric_limits::max()}); + + auto integerVec = makeFlatVector( + {-2147483648, 0, 2147483647, std::numeric_limits::max()}); + + auto bigintVec = makeFlatVector( + {-3435678987654321234LL, + 0, + 4527897896541234567LL, + std::numeric_limits::max()}); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"tinyint_col", "smallint_col", "integer_col", "bigint_col"}, + {velox::TINYINT(), + velox::SMALLINT(), + velox::INTEGER(), + velox::BIGINT()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults( + makeRowVector({tinyintVec, smallintVec, integerVec, bigintVec})); +} + +TEST_F(FlightConnectorDataTypeTest, realType) { + updateTable( + "sample-data", + makeArrowTable( + {"real_col", "double_col"}, + {makeNumericArray( + {std::numeric_limits::min(), + 0.0f, + 3.14f, + std::numeric_limits::max()}), + makeNumericArray( + {std::numeric_limits::min(), + 0.0, + 3.14159, + std::numeric_limits::max()})})); + + auto realVec = makeFlatVector( + {std::numeric_limits::min(), + 0.0f, + 3.14f, + std::numeric_limits::max()}); + auto doubleVec = makeFlatVector( + {std::numeric_limits::min(), + 0.0, + 3.14159, + std::numeric_limits::max()}); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"real_col", "double_col"}, {velox::REAL(), velox::DOUBLE()})) + .planNode(); + + 
AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({realVec, doubleVec})); +} + +TEST_F(FlightConnectorDataTypeTest, varcharType) { + updateTable( + "sample-data", + makeArrowTable( + {"varchar_col"}, {makeStringArray({"Hello", "World", "India"})})); + + auto vec = makeFlatVector( + {facebook::velox::StringView("Hello"), + facebook::velox::StringView("World"), + facebook::velox::StringView("India")}); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW({"varchar_col"}, {velox::VARCHAR()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({vec})); +} + +TEST_F(FlightConnectorDataTypeTest, timestampType) { + auto timestampValues = + std::vector{1622538000, 1622541600, 1622545200}; + + updateTable( + "sample-data", + makeArrowTable( + {"timestampsec_col", "timestampmilli_col", "timestampmicro_col"}, + {makeTimestampArray(timestampValues, arrow::TimeUnit::SECOND), + makeTimestampArray(timestampValues, arrow::TimeUnit::MILLI), + makeTimestampArray(timestampValues, arrow::TimeUnit::MICRO)})); + + std::vector veloxTimestampSec; + for (const auto& ts : timestampValues) { + veloxTimestampSec.emplace_back(ts, 0); // Assuming 0 microseconds part + } + + auto timestampSecCol = + makeFlatVector(veloxTimestampSec); + + std::vector veloxTimestampMilli; + for (const auto& ts : timestampValues) { + veloxTimestampMilli.emplace_back( + ts / 1000, (ts % 1000) * 1000000); // Convert to seconds and nanoseconds + } + + auto timestampMilliCol = + makeFlatVector(veloxTimestampMilli); + + std::vector veloxTimestampMicro; + for (const auto& ts : timestampValues) { + veloxTimestampMicro.emplace_back( + ts / 1000000, + (ts % 1000000) * 1000); // Convert to seconds and nanoseconds + } + + auto timestampMicroCol = + makeFlatVector(veloxTimestampMicro); + + core::PlanNodePtr plan; + plan = + FlightPlanBuilder() + .flightTableScan(velox::ROW( + 
{"timestampsec_col", "timestampmilli_col", "timestampmicro_col"}, + {velox::TIMESTAMP(), velox::TIMESTAMP(), velox::TIMESTAMP()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector( + {timestampSecCol, timestampMilliCol, timestampMicroCol})); +} + +TEST_F(FlightConnectorDataTypeTest, dateDayType) { + std::vector datesDay = {18748, 18749, 18750}; // Days since epoch + std::vector datesMilli = { + 1622538000000, 1622541600000, 1622545200000}; // Milliseconds since epoch + + updateTable( + "sample-data", + makeArrowTable( + {"daydate_col", "daymilli_col"}, + {makeNumericArray(datesDay), + makeNumericArray(datesMilli)})); + + auto dateVec = makeFlatVector(datesDay); + auto milliVec = makeFlatVector(datesMilli); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW({"daydate_col"}, {velox::DATE()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({dateVec})); + + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW({"daymilli_col"}, {velox::DATE()})) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({milliVec})), + "Unable to convert 'tdm' ArrowSchema format type to Velox"); +} + +TEST_F(FlightConnectorDataTypeTest, decimalType) { + std::vector decimalValuesBigInt = { + 123456789012345678, + -123456789012345678, + std::numeric_limits::max()}; + std::vector> decimalArrayVec; + decimalArrayVec.push_back(makeDecimalArray(decimalValuesBigInt, 18, 2)); + updateTable( + "sample-data", makeArrowTable({"decimal_col_bigint"}, decimalArrayVec)); + auto decimalVecBigInt = makeFlatVector(decimalValuesBigInt); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"decimal_col_bigint"}, + {velox::DECIMAL(18, 2)})) // precision can't be 0 and < scale + .planNode(); + + // Execute the query and assert 
the results + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({decimalVecBigInt})); +} + +TEST_F(FlightConnectorDataTypeTest, allTypes) { + auto timestampValues = + std::vector{1622550000, 1622553600, 1622557200}; + + auto sampleTable = makeArrowTable( + {"id", + "daydate_col", + "timestamp_col", + "varchar_col", + "real_col", + "int_col", + "bool_col"}, + {makeNumericArray({1, 2, 3}), + makeNumericArray({18748, 18749, 18750}), + makeTimestampArray(timestampValues, arrow::TimeUnit::SECOND), + makeStringArray({"apple", "banana", "cherry"}), + makeNumericArray({3.14, 2.718, 1.618}), + makeNumericArray( + {-32768, 32767, std::numeric_limits::max()}), + makeBooleanArray({true, false, true})}); + + updateTable("gen-data", sampleTable); + + auto dateVec = makeFlatVector({18748, 18749, 18750}); + + std::vector veloxTimestampSec; + for (const auto& ts : timestampValues) { + veloxTimestampSec.emplace_back(ts, 0); // Assuming 0 microseconds part + } + auto timestampSecVec = + makeFlatVector(veloxTimestampSec); + + auto stringVec = makeFlatVector( + {facebook::velox::StringView("apple"), + facebook::velox::StringView("banana"), + facebook::velox::StringView("cherry")}); + auto realVec = makeFlatVector({3.14, 2.718, 1.618}); + auto intVec = makeFlatVector( + {-32768, 32767, std::numeric_limits::max()}); + auto boolVec = makeFlatVector({true, false, true}); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"daydate_col", + "timestamp_col", + "varchar_col", + "real_col", + "int_col", + "bool_col"}, + {velox::DATE(), + velox::TIMESTAMP(), + velox::VARCHAR(), + velox::DOUBLE(), + velox::INTEGER(), + velox::BOOLEAN()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"gen-data"})) + .assertResults(makeRowVector( + {dateVec, timestampSecVec, stringVec, realVec, intVec, boolVec})); +} + +} // namespace + +} // namespace facebook::presto::connector::arrow_flight::test diff --git 
a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp new file mode 100644 index 0000000000000..90d9a9a6c8930 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp @@ -0,0 +1,204 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "arrow/testing/gtest_util.h" +#include "folly/init/Init.h" +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" + +namespace facebook::presto::connector::arrow_flight::test { + +namespace { + +namespace velox = facebook::velox; +namespace core = facebook::velox::core; +namespace exec = facebook::velox::exec; +namespace connector = 
facebook::velox::connector; +namespace flight = arrow::flight; + +using namespace facebook::presto::connector::arrow_flight; +using namespace facebook::presto::connector::arrow_flight::test; +using exec::test::AssertQueryBuilder; +using exec::test::OperatorTestBase; + +static const std::string kFlightConnectorId = "test-flight"; + +class FlightConnectorTest : public FlightWithServerTestBase {}; + +TEST_F(FlightConnectorTest, invalidSplitTest) { + auto plan = FlightPlanBuilder() + .flightTableScan(velox::ROW({{"id", velox::BIGINT()}})) + .planNode(); + + VELOX_ASSERT_THROW( + velox::exec::test::AssertQueryBuilder(plan) + .splits(makeSplits({"unknown"})) + .copyResults(pool()), + "Server replied with error"); +} + +TEST_F(FlightConnectorTest, dataSourceCreationTest) { + // missing columnHandle test + auto plan = + FlightPlanBuilder() + .flightTableScan( + velox::ROW({"id", "value"}, {velox::BIGINT(), velox::INTEGER()}), + {{"id", std::make_shared("id")}}, + false /*createDefaultColumnHandles*/) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .copyResults(pool()), + "missing columnHandle for column 'value'"); +} + +TEST_F(FlightConnectorTest, dataSourceTest) { + updateTable( + "sample-data", + makeArrowTable( + {"id", "value", "unsigned"}, + {makeNumericArray( + {1, 12, 2, std::numeric_limits::max()}), + makeNumericArray( + {41, 42, 43, std::numeric_limits::min()}), + // note that velox doesn't support unsigned types + // connector should still be able to query such tables + // as long as this specifc column isn't requested. 
+ makeNumericArray( + {101, 102, 12, std::numeric_limits::max()})})); + + auto idColumn = std::make_shared("id"); + auto idVec = + makeFlatVector({1, 12, 2, std::numeric_limits::max()}); + + auto valueColumn = std::make_shared("value"); + auto valueVec = makeFlatVector( + {41, 42, 43, std::numeric_limits::min()}); + + core::PlanNodePtr plan; + + // direct test + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, valueVec})); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"}, std::vector{""})) + .assertResults(makeRowVector({idVec, valueVec})), + "URI has empty scheme"); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"}, std::vector{})) + .assertResults(makeRowVector({idVec, valueVec})), + "default host or port is missing"); + + // column alias test + plan = + FlightPlanBuilder() + .flightTableScan( + velox::ROW({"ducks", "id"}, {velox::BIGINT(), velox::BIGINT()}), + {{"ducks", idColumn}}) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, idVec})); + + // invalid columnHandle test + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"ducks", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .copyResults(pool()), + "column with name 'ducks' not found"); +} + +class FlightConnectorTestDefaultServer : public FlightWithServerTestBase { + public: + FlightConnectorTestDefaultServer() + : FlightWithServerTestBase(std::make_shared( + std::unordered_map{ + {FlightConfig::kDefaultServerHost, CONNECT_HOST}, + {FlightConfig::kDefaultServerPort, + std::to_string(LISTEN_PORT)}})) {} +}; + +TEST_F(FlightConnectorTestDefaultServer, 
dataSourceTest) { + updateTable( + "sample-data", + makeArrowTable( + {"id", "value"}, + {makeNumericArray( + {1, 12, 2, std::numeric_limits::max()}), + makeNumericArray( + {41, 42, 43, std::numeric_limits::min()})})); + + auto idColumn = std::make_shared("id"); + auto idVec = + makeFlatVector({1, 12, 2, std::numeric_limits::max()}); + + auto valueColumn = std::make_shared("value"); + auto valueVec = makeFlatVector( + {41, 42, 43, std::numeric_limits::min()}); + + core::PlanNodePtr plan; + + // direct test + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, valueVec})); + + AssertQueryBuilder(plan) + .splits(makeSplits( + {"sample-data"}, + std::vector{})) // Using default connector + .assertResults(makeRowVector({idVec, valueVec})); +} + +} // namespace + +} // namespace facebook::presto::connector::arrow_flight::test + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::Init init{&argc, &argv, false}; + return RUN_ALL_TESTS(); +} diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp new file mode 100644 index 0000000000000..d5a3ef0f8cdd4 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp @@ -0,0 +1,126 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "arrow/flight/types.h" +#include "arrow/testing/gtest_util.h" +#include "folly/init/Init.h" +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" + +namespace facebook::presto::connector::arrow_flight::test { + +namespace { + +namespace velox = facebook::velox; +namespace core = facebook::velox::core; +namespace exec = facebook::velox::exec; +namespace connector = facebook::velox::connector; +namespace flight = arrow::flight; + +using namespace facebook::presto::connector::arrow_flight; +using namespace facebook::presto::connector::arrow_flight::test; +using exec::test::AssertQueryBuilder; +using exec::test::OperatorTestBase; + +class FlightConnectorTlsTestBase : public FlightWithServerTestBase { + protected: + explicit FlightConnectorTlsTestBase( + std::shared_ptr config) + : FlightWithServerTestBase( + std::move(config), + createFlightServerOptions( + true, /* isSecure */ + "./data/tls_certs/server.crt", + "./data/tls_certs/server.key")) {} + + void executeTest(bool isPositiveTest = true) { + updateTable( + "sample-data", 
+ makeArrowTable( + {"id"}, + {makeNumericArray( + {1, 12, 2, std::numeric_limits::max()})})); + + auto idVec = makeFlatVector( + {1, 12, 2, std::numeric_limits::max()}); + + auto plan = FlightPlanBuilder() + .flightTableScan(velox::ROW({"id"}, {velox::BIGINT()})) + .planNode(); + + auto loc = std::vector{ + fmt::format("grpc+tls://{}:{}", CONNECT_HOST, LISTEN_PORT)}; + if (isPositiveTest) { + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"}, loc)) + .assertResults(makeRowVector({idVec})); + } else { + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"}, loc)) + .assertResults(makeRowVector({idVec})), + "Server replied with error"); + } + } +}; + +class FlightConnectorTlsTest : public FlightConnectorTlsTestBase { + protected: + explicit FlightConnectorTlsTest() + : FlightConnectorTlsTestBase(std::make_shared( + std::unordered_map{ + {FlightConfig::kServerVerify, "true"}, + {FlightConfig::kServerSslCertificate, + "./data/tls_certs/ca.crt"}})) {} +}; + +TEST_F(FlightConnectorTlsTest, tlsTest) { + executeTest(); +} + +class FlightConnectorTlsNoCertValidationTest + : public FlightConnectorTlsTestBase { + protected: + explicit FlightConnectorTlsNoCertValidationTest() + : FlightConnectorTlsTestBase(std::make_shared( + std::unordered_map{ + {FlightConfig::kServerVerify, "false"}})) {} +}; + +TEST_F(FlightConnectorTlsNoCertValidationTest, tlsNoCertValidationTest) { + executeTest(); +} + +class FlightConnectorTlsNoCertTest : public FlightConnectorTlsTestBase { + protected: + FlightConnectorTlsNoCertTest() + : FlightConnectorTlsTestBase(std::make_shared( + std::unordered_map{ + {FlightConfig::kServerVerify, "true"}})) {} +}; + +TEST_F(FlightConnectorTlsNoCertTest, tlsNoCertTest) { + executeTest(false /* isPositiveTest */); +} + +} // namespace + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt 
b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt
new file mode 100644
index 0000000000000..e427167a93231
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt
@@ -0,0 +1,42 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Shared test helpers (test Flight server, fixtures, plan builder).
+add_subdirectory(utils)
+
+# Test binary that exercises only the in-process test Flight server.
+add_executable(presto_flight_connector_infra_test TestFlightServerTest.cpp)
+
+add_test(presto_flight_connector_infra_test presto_flight_connector_infra_test)
+
+target_link_libraries(presto_flight_connector_infra_test
+                      presto_flight_connector_test_lib gtest gtest_main glog)
+
+# Test binary for the connector itself (scan, TLS, data types, config).
+add_executable(
+  presto_flight_connector_test
+  ArrowFlightConnectorTest.cpp ArrowFlightConnectorTlsTest.cpp
+  ArrowFlightConnectorDataTypeTest.cpp FlightConfigTest.cpp)
+
+# TLS fixtures checked into the source tree.
+set(DATA_DIR "${CMAKE_CURRENT_SOURCE_DIR}/data/tls_certs")
+
+# Copies the TLS fixtures into a data/tls_certs directory as part of the
+# default build (ALL). The TLS tests open them via the relative path
+# ./data/tls_certs, so they must be run from the directory that receives the copy.
+add_custom_target(
+  copy_flight_test_data ALL
+  COMMAND ${CMAKE_COMMAND} -E copy_directory ${DATA_DIR}
+          $/data/tls_certs)
+
+add_test(presto_flight_connector_test presto_flight_connector_test)
+
+target_link_libraries(
+  presto_flight_connector_test
+  velox_exec_test_lib
+  presto_flight_connector
+  gtest
+  gtest_main
+  presto_flight_connector_test_lib
+  presto_protocol)
diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/FlightConfigTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/FlightConfigTest.cpp
new file mode 100644
index 0000000000000..8777f6f3f8956
---
/dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/FlightConfigTest.cpp
@@ -0,0 +1,49 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "presto_cpp/main/connectors/arrow_flight/FlightConfig.h"
+#include "gtest/gtest.h"
+
+namespace facebook::presto::connector::arrow_flight::test {
+
+// Verifies the values FlightConfig reports when built from an empty map.
+TEST(FlightConfigTest, defaultConfig) {
+  auto rawConfig = std::make_shared(
+      std::move(std::unordered_map{}));
+  auto config = FlightConfig(rawConfig);
+  ASSERT_EQ(config.authenticatorName(), "none");
+  ASSERT_EQ(config.defaultServerHostname(), std::nullopt);
+  ASSERT_EQ(config.defaultServerPort(), std::nullopt);
+  ASSERT_EQ(config.defaultServerSslEnabled(), false);
+  ASSERT_EQ(config.serverVerify(), true);
+  // NOTE(review): this getter is compared against folly::none while the other
+  // optional getters above use std::nullopt -- confirm the return type.
+  ASSERT_EQ(config.serverSslCertificate(), folly::none);
+}
+
+// Verifies that every key set in the map is reflected by the matching getter,
+// including string-to-number ("9000") and string-to-bool conversions.
+TEST(FlightConfigTest, overrideConfig) {
+  std::unordered_map configMap = {
+      {FlightConfig::kAuthenticatorName, "my-authenticator"},
+      {FlightConfig::kDefaultServerHost, "my-server-host"},
+      {FlightConfig::kDefaultServerPort, "9000"},
+      {FlightConfig::kDefaultServerSslEnabled, "true"},
+      {FlightConfig::kServerVerify, "false"},
+      {FlightConfig::kServerSslCertificate, "my-cert.crt"}};
+  auto config = FlightConfig(
+      std::make_shared(std::move(configMap)));
+  ASSERT_EQ(config.authenticatorName(), "my-authenticator");
+  ASSERT_EQ(config.defaultServerHostname(), "my-server-host");
+  ASSERT_EQ(config.defaultServerPort(), 9000);
+  ASSERT_EQ(config.defaultServerSslEnabled(), true);
+  ASSERT_EQ(config.serverVerify(), false);
+  ASSERT_EQ(config.serverSslCertificate(), "my-cert.crt");
+}
+
+} // namespace facebook::presto::connector::arrow_flight::test
diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestFlightServerTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestFlightServerTest.cpp
new file mode 100644
index 0000000000000..615b9868e359f
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestFlightServerTest.cpp
@@ -0,0 +1,84 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h" +#include "arrow/api.h" +#include "arrow/flight/api.h" +#include "arrow/testing/gtest_util.h" +#include "glog/logging.h" +#include "gtest/gtest.h" + +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" + +namespace { + +using namespace facebook::presto::connector::arrow_flight::test; +using namespace arrow::flight; + +class TestFlightServerTest : public testing::Test { + public: + static void SetUpTestSuite() { + server = std::make_unique(); + ASSERT_OK_AND_ASSIGN(auto loc, Location::ForGrpcTcp("127.0.0.1", 0)); + ASSERT_OK(server->Init(FlightServerOptions(loc))); + } + + static void TearDownTestSuite() { + ASSERT_OK(server->Shutdown()); + } + + static void updateTable( + std::string name, + std::shared_ptr table) { + server->updateTable(std::move(name), std::move(table)); + } + + void SetUp() { + ASSERT_OK_AND_ASSIGN( + auto loc, Location::ForGrpcTcp("localhost", server->port())); + ASSERT_OK_AND_ASSIGN(client, FlightClient::Connect(loc)); + } + + std::unique_ptr client; + static std::unique_ptr server; +}; + +std::unique_ptr TestFlightServerTest::server; + +TEST_F(TestFlightServerTest, basicTest) { + auto sampleTable = makeArrowTable( + {"id", "value"}, + {makeNumericArray({1, 2}), + makeNumericArray({41, 42})}); + updateTable("sample-data", sampleTable); + + ASSERT_RAISES(KeyError, client->DoGet(Ticket{"empty"})); + + auto emptyTable = makeArrowTable({}, {}); + updateTable("empty", emptyTable); + + ASSERT_RAISES(KeyError, client->DoGet(Ticket{"non-existent-table"})); + + ASSERT_OK_AND_ASSIGN(auto reader, client->DoGet(Ticket{"empty"})); + ASSERT_OK_AND_ASSIGN(auto actual, reader->ToTable()); + EXPECT_TRUE(actual->Equals(*emptyTable)); + + ASSERT_OK_AND_ASSIGN(reader, client->DoGet(Ticket{"sample-data"})); + ASSERT_OK_AND_ASSIGN(actual, reader->ToTable()); + EXPECT_TRUE(actual->Equals(*sampleTable)); + + server->removeTable("sample-data"); + 
ASSERT_RAISES(KeyError, client->DoGet(Ticket{"sample-data"})); +} + +} // namespace diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md new file mode 100644 index 0000000000000..bac4938fc47d5 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md @@ -0,0 +1,6 @@ +### Placeholder TLS Certificates for Arrow Flight Connector Unit Testing +The `tls_certs` directory contains placeholder TLS certificates generated for unit testing the Arrow Flight Connector with TLS enabled. These certificates are not intended for production use and should only be used in the context of unit tests. + +### Generating TLS Certificates +To create the TLS certificates and keys inside the `tls_certs` folder, run the following command: +`./generate_tls_certs.sh` diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh new file mode 100755 index 0000000000000..718f313c70a75 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Set directory for certificates and keys. +CERT_DIR="./tls_certs" +mkdir -p $CERT_DIR + +# Dummy values for the certificates. +COUNTRY="US" +STATE="State" +LOCALITY="City" +ORGANIZATION="MyOrg" +ORG_UNIT="MyUnit" +COMMON_NAME="MyCA" +SERVER_CN="server.mydomain.com" + +# Step 1: Generate CA private key and self-signed certificate. +openssl genpkey -algorithm RSA -out $CERT_DIR/ca.key +openssl req -key $CERT_DIR/ca.key -new -x509 -out $CERT_DIR/ca.crt -days 365000 \ + -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$COMMON_NAME" + +# Step 2: Generate server private key. 
+openssl genpkey -algorithm RSA -out $CERT_DIR/server.key + +# Step 3: Generate server certificate signing request (CSR). +openssl req -new -key $CERT_DIR/server.key -out $CERT_DIR/server.csr \ + -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$SERVER_CN" \ + -addext "subjectAltName=DNS:$COMMON_NAME,DNS:localhost" \ + +# Step 4: Sign server CSR with the CA certificate to generate the server certificate. +openssl x509 -req -in $CERT_DIR/server.csr -CA $CERT_DIR/ca.crt -CAkey $CERT_DIR/ca.key \ + -CAcreateserial -out $CERT_DIR/server.crt -days 365000 \ + -extfile <(printf "subjectAltName=DNS:$COMMON_NAME,DNS:localhost") + +# Step 5: Output the generated files. +echo "Certificate Authority (CA) certificate: $CERT_DIR/ca.crt" +echo "Server certificate: $CERT_DIR/server.crt" +echo "Server private key: $CERT_DIR/server.key" + +# Step 6: Remove unused files. +rm -rf $CERT_DIR/server.csr $CERT_DIR/ca.srl $CERT_DIR/ca.key diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt new file mode 100644 index 0000000000000..6740e89c54e17 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDmzCCAoOgAwIBAgIUf+rP48iL39yGlAfFQTIp5bmM4uQwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MQ0wCwYDVQQDDARNeUNB +MCAXDTI0MTIwMzExMDQxMVoYDzMwMjQwNDA1MTEwNDExWjBcMQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxDjAMBgNVBAoMBU15T3Jn +MQ8wDQYDVQQLDAZNeVVuaXQxDTALBgNVBAMMBE15Q0EwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCliiXIcSmxXAAq2k/XjcZniDgEDCxWKZGiV8JBiJwY +MMBJtqcVzWfiDpO2u6d1dfGb6utlRW+1dnwupzURCMmZff4bqlPx4ZejRXDrWzKz +08WSpDVZwC2H5XOllwK36Cn4gvPRe3YWVcdDGHy7GL+zsJENvawJj0BH952MU4bk 
+sV52zEkN291bfN9sSYfT1NCJuLPM0Qsf97DeQ+wHXEw+t4XVMF3FQbciQp0y6CnA +wfFFN14WDiWxukP1I3kuDYYA6h/WJCQMp5rU2NCB9nIQrulYRxFaepMYENLxgAyj +gFaoRh2Kt2k7XKv6WOa6CmYm2dZERPlbA+oNAHkaHw6lAgMBAAGjUzBRMB0GA1Ud +DgQWBBSN+3vRlXGjs6c+rN94qgEnkPLl3DAfBgNVHSMEGDAWgBSN+3vRlXGjs6c+ +rN94qgEnkPLl3DAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAb +L40Oe2b/2xdUSyjqXJceVxaeA291fCpnu1C1JliP0hNI3fu9jjZhXHQoYub/4mod +8lriEDIcOCCiUfmi404akpqQHuBmOHaKEOtaaQkezjPsYnUra+O2ssqUo2zto5bK +gR0LGsb+4AO0bDvq+QVI6kEQqAAIf6qC+kpg/jV4iKJ1J6Qw4R3QppYBm6SQcfvI +hfUfDSO6SNfy0f/ZVCavbJIP9zG/BfAD9DEERocw03PiN5bm4IXJ3HH8rxyuBfJ5 +Eg/fPP5TlZ2H7Kqb3VgVBGWJtNXWmJphHyraBJTEuxgXWvl6AaW0P/3dsJi3rfdD +zDIT7AmENLCom8Gl0bgM +-----END CERTIFICATE----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt new file mode 100644 index 0000000000000..92c91f2d613b0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIUUhmhZP94nIowrg2EarzfEBp6W1EwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MQ0wCwYDVQQDDARNeUNB +MCAXDTI0MTIwMzExMDQxMVoYDzMwMjQwNDA1MTEwNDExWjBrMQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxDjAMBgNVBAoMBU15T3Jn +MQ8wDQYDVQQLDAZNeVVuaXQxHDAaBgNVBAMME3NlcnZlci5teWRvbWFpbi5jb20w +ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDSxC4zCC4GFZbX+fdFgWbL +sj4PortyOM7mzRjNaQ3M0FTSEy5xET9C2qFlBCJ7AL7DlbSLmSckYY/FkdfMqNN2 ++NZ0Dy2d6bZN+ly5N/QBVnyS/5HVC3MXa6Y2BmFXiBnczWfGBwj+uVHlKOUWUyNi +EyUkhuPwtYXkFmJoqBxJSPC6cxX6NzMujnwCF18dUf0Vra44osu4moaovmg3c9jM +cBtmafFs9F54FoAEuLotjISVEa7VY6th5RxXJHpgas+0R5EBddGYKbTRiUYjht7r +pS+An0ey02oOjEWdqLnQSg/SUGKuRXULyE5l1A0HfNQtvepUQotb9ull1F7OrbfB +AgMBAAGjXjBcMBoGA1UdEQQTMBGCBE15Q0GCCWxvY2FsaG9zdDAdBgNVHQ4EFgQU 
+vnCLWjre4jqkKzC24psCPh1oIQwwHwYDVR0jBBgwFoAUjft70ZVxo7OnPqzfeKoB +J5Dy5dwwDQYJKoZIhvcNAQELBQADggEBAJCiJgtTw/7b/g3QnJfM4teQkFS440Ii +weqQJMoP6als8Fc3opPKv9eC5w0wqaLlIdwJjzGM5PmCAtGVafo22TbqhZyQdzQu +TUKv1DaVF0JBVAGVxTSDIK9r5Ww4mDAQnQENLC6soS3AvYDEi+8667YLoNNdhRCX +q2D5v76UN45idiShppxOw53whsvpHv+wyqcdse7DhgM9boCbx51Uvv3l/AEToyaj +S1xeIkBwNpSYU0ax2Lr1j2yoKbzAa3MHy8Php+T5CGji02+HwwlvlPDLtw8q5gHw +BLSwlAHgclPxUTWNNoCqjfX8Bi083+QDCLm0rgQ45xljNDbFAF1Y5hA= +-----END CERTIFICATE----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key new file mode 100644 index 0000000000000..2cdf5750a4753 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDSxC4zCC4GFZbX ++fdFgWbLsj4PortyOM7mzRjNaQ3M0FTSEy5xET9C2qFlBCJ7AL7DlbSLmSckYY/F +kdfMqNN2+NZ0Dy2d6bZN+ly5N/QBVnyS/5HVC3MXa6Y2BmFXiBnczWfGBwj+uVHl +KOUWUyNiEyUkhuPwtYXkFmJoqBxJSPC6cxX6NzMujnwCF18dUf0Vra44osu4moao +vmg3c9jMcBtmafFs9F54FoAEuLotjISVEa7VY6th5RxXJHpgas+0R5EBddGYKbTR +iUYjht7rpS+An0ey02oOjEWdqLnQSg/SUGKuRXULyE5l1A0HfNQtvepUQotb9ull +1F7OrbfBAgMBAAECggEAAxbZuuESGGAMMm9HLGhKHgbHU8gnv2Phdbrka+SYBYg5 +UYzTHLh3FwEsjd4VnaweJ7CN1WDb1NvWmTum/DCebJ1HKqtjKLAZfk8q2TLGmXdL +pzWOdQ8MX1fKP2sIlcl0kFbNCE8vprjneDyBLtqOK36eiAh/fl6BQ12QAMLjyv/L +OwXSY4ESs/RzxRzFgdT98cDZFL7y0FVIjJo/Q5lfW9UwwSfw8tOLNXKTYwPHqIfJ +NjfWD7IqztQlnanyRXv5dScp80i8p9qgH0i8YfVBHZDeOmHGLcltilLRZ0dQ/X0g +Lrr0aIO3iLhmTIkJRzUnGeyvDjxcPINvRSBBwXy04QKBgQDpFJa/EwSsWj8586oh +xgm0Z3q+FiEeCe7aLLPcXAS2EDvix5ibJDT2y1Aa/kXq25S53npa/7Ov6TJs5H4g +eyshDtR1wVhz+rIggREiX/sagkhwnNsssUZFv5t9PdnaFXpVnH49m5Qc8HO3owtN +t8EGSRcAQ4o/fLWLs51qd38cIQKBgQDnfd8YPyDQ03xDC/3+Qrypyc/xhGnCuj7w +ZeA5iEyTnnNxL0a0B6PWcSk2BZReMNQKgYtipnsOQKtwHMttxtXYs/VQpeB4KoWE 
+zEwW0fV3MMsXN+nVJlEZnVaTbmYXknjeZrh/rNjsY96yxw8NtvAuYSpnqtr3N2nd +iMQ3G/QnoQKBgGMi+bdNvIgeXpQkmrGAzTHpbaCaQv3G1cwAhYPts6dIomAj6znZ +nZl3ApxomI57VPf1s+8uoVvqASOl0Cu6l66Y4y8uzJOQBuGiZApN7rzouy0C2opY +4H3cMKOFgjqrNfxh8qP7n3TrpRxvgehNhxFIVzsqfwvf3EwOWp8lMnBhAoGAZ25E +Ge9K2ENGCCb5i3uCFFLJiF3ja1AQAxVhxBL0NBjd97pp2tJ3D79r7Gk9y4ABndgX +0TIVVV7ruqIC+r+WmMZ/W1NiIg7NrXIipSeWh3TTqUIgRk5iehFkt2biUrHtM2Gu +Gc2+9pAA1tw+C6CrW+2qJrueLksiEAulsAHba0ECgYBIgIiY+Gx+XecEgCwAhWcn +GzNDAAlA4IgBjHpUtIByflzQDqlECKXjPbVBKfyq6eLt40upFmQCLsn+AkiQau8A +3cFAK9wJOAHv9KuWDrbHyhRE9CrJ6BqsY2goC3LiFCTgJy1TrRl6CDaFzHivONwF +LNPflYk5s376UWqxC+HtIA== +-----END PRIVATE KEY----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt new file mode 100644 index 0000000000000..3ddf148e5671b --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt @@ -0,0 +1,19 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Helper library shared by the Arrow Flight connector tests: the in-process
+# test Flight server, the gtest fixtures, Arrow data builders and the plan
+# builder extension.
+add_library(
+  presto_flight_connector_test_lib
+  TestFlightServer.cpp FlightConnectorTestBase.cpp Utils.cpp
+  FlightPlanBuilder.cpp)
+
+target_link_libraries(
+  presto_flight_connector_test_lib arrow ${ArrowFlight_LIBRARIES}
+  velox_exception presto_flight_connector_utils velox_exec_test_lib)
diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.cpp
new file mode 100644
index 0000000000000..24fceeab2b4f9
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.cpp
@@ -0,0 +1,93 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h" +#include "arrow/testing/gtest_util.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" + +namespace facebook::presto::connector::arrow_flight::test { + +namespace connector = velox::connector; +namespace core = velox::core; + +using namespace arrow::flight; +using velox::exec::test::OperatorTestBase; + +void FlightConnectorTestBase::SetUp() { + OperatorTestBase::SetUp(); + + if (!velox::connector::hasConnectorFactory( + presto::connector::arrow_flight::ArrowFlightConnectorFactory:: + kArrowFlightConnectorName)) { + connector::registerConnectorFactory( + std::make_shared< + presto::connector::arrow_flight::ArrowFlightConnectorFactory>()); + } + connector::registerConnector( + connector::getConnectorFactory( + ArrowFlightConnectorFactory::kArrowFlightConnectorName) + ->newConnector(kFlightConnectorId, config_)); +} + +void FlightConnectorTestBase::TearDown() { + connector::unregisterConnector(kFlightConnectorId); + OperatorTestBase::TearDown(); +} + +void FlightWithServerTestBase::SetUp() { + FlightConnectorTestBase::SetUp(); + + server_ = std::make_unique(); + ASSERT_OK(server_->Init(*options_)); +} + +void FlightWithServerTestBase::TearDown() { + ASSERT_OK(server_->Shutdown()); + FlightConnectorTestBase::TearDown(); +} + +std::vector> +FlightWithServerTestBase::makeSplits( + const std::initializer_list& tickets, + const std::vector& location) { + std::vector> splits; + splits.reserve(tickets.size()); + for (auto& ticket : tickets) { + splits.push_back( + std::make_shared(kFlightConnectorId, ticket, location)); + } + return splits; +} + +std::shared_ptr +FlightWithServerTestBase::createFlightServerOptions( + bool isSecure, + const std::string& certPath, + const std::string& keyPath) { + AFC_ASSIGN_OR_RAISE( + auto 
loc, + isSecure ? Location::ForGrpcTls(BIND_HOST, LISTEN_PORT) + : Location::ForGrpcTcp(BIND_HOST, LISTEN_PORT)); + auto options = std::make_shared(loc); + if (!isSecure) + return options; + + CertKeyPair tlsCertificate{ + .pem_cert = readFile(certPath), .pem_key = readFile(keyPath)}; + options->tls_certificates.push_back(tlsCertificate); + return options; +} + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h new file mode 100644 index 0000000000000..382fe6318e39f --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#pragma once
+
+#include "arrow/flight/api.h"
+#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h"
+#include "velox/common/config/Config.h"
+#include "velox/connectors/Connector.h"
+#include "velox/exec/tests/utils/OperatorTestBase.h"
+
+namespace facebook::presto::connector::arrow_flight::test {
+
+/// Connector id under which the tests register the Arrow Flight connector.
+static const std::string kFlightConnectorId = "test-flight";
+
+/// Test fixture that registers an Arrow Flight connector (id
+/// kFlightConnectorId) in SetUp() and unregisters it in TearDown().
+/// The connector is created from the config handed to the constructor.
+class FlightConnectorTestBase : public velox::exec::test::OperatorTestBase {
+ public:
+  void SetUp() override;
+
+  void TearDown() override;
+
+ protected:
+  /// @param config Connector configuration used when creating the connector.
+  explicit FlightConnectorTestBase(
+      std::shared_ptr config)
+      : config_{std::move(config)} {}
+
+  /// Uses an empty connector configuration.
+  FlightConnectorTestBase()
+      : config_{std::make_shared(
+            std::move(std::unordered_map{}))} {}
+
+ private:
+  std::shared_ptr config_;
+};
+
+/// Creates and registers an arrow flight connector and
+/// spawns a Flight server for testing.
+/// Initially there is no data in the Flight server,
+/// tests should call FlightWithServerTestBase::updateTable to populate it.
+class FlightWithServerTestBase : public FlightConnectorTestBase { + public: + static constexpr const char* BIND_HOST = "127.0.0.1"; + static constexpr const char* CONNECT_HOST = "localhost"; + constexpr static int LISTEN_PORT = 5000; + + void SetUp() override; + + void TearDown() override; + + /// Convenience method which creates splits for the test flight server + std::vector> makeSplits( + const std::initializer_list& tokens, + const std::vector& location = std::vector{ + fmt::format("grpc://{}:{}", CONNECT_HOST, LISTEN_PORT)}); + + /// Add (or update) a table in the test flight server + void updateTable(std::string name, std::shared_ptr table) { + server_->updateTable(std::move(name), std::move(table)); + } + + protected: + explicit FlightWithServerTestBase( + std::shared_ptr config) + : FlightConnectorTestBase{std::move(config)}, + options_{createFlightServerOptions()} {} + + FlightWithServerTestBase() + : FlightConnectorTestBase(), options_{createFlightServerOptions()} {} + + explicit FlightWithServerTestBase( + std::shared_ptr config, + std::shared_ptr options) + : FlightConnectorTestBase{std::move(config)}, options_{options} {} + + std::shared_ptr createFlightServerOptions( + bool isSecure = false, + const std::string& certPath = "", + const std::string& keyPath = ""); + + private: + std::unique_ptr server_; + std::shared_ptr options_; +}; + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.cpp new file mode 100644 index 0000000000000..42194df9c890e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.cpp @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h"
+#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h"
+
+namespace facebook::presto::connector::arrow_flight::test {
+
+// NOTE(review): same name and value as the constant in
+// FlightConnectorTestBase.h; the two definitions must be kept in sync.
+static const std::string kFlightConnectorId = "test-flight";
+
+// Appends a table scan that reads through the connector registered as
+// kFlightConnectorId. `assignments` is taken by value so that defaults can be
+// added to the local copy before it is moved into the scan node.
+velox::exec::test::PlanBuilder& FlightPlanBuilder::flightTableScan(
+    const velox::RowTypePtr& outputType,
+    std::unordered_map<
+        std::string,
+        std::shared_ptr> assignments,
+    bool createDefaultColumnHandles) {
+  if (createDefaultColumnHandles) {
+    for (const auto& name : outputType->names()) {
+      // Provide unaliased defaults for unmapped columns.
+      // `emplace` won't modify the map if the key already exists,
+      // so existing aliases are kept.
+      assignments.emplace(name, std::make_shared(name));
+    }
+  }
+
+  return startTableScan()
+      .tableHandle(std::make_shared(kFlightConnectorId))
+      .outputType(outputType)
+      .assignments(std::move(assignments))
+      .endTableScan();
+}
+
+} // namespace facebook::presto::connector::arrow_flight::test
diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h
new file mode 100644
index 0000000000000..bfc75c3704585
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h
@@ -0,0 +1,35 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "velox/exec/tests/utils/PlanBuilder.h"
+
+namespace facebook::presto::connector::arrow_flight::test {
+
+/// PlanBuilder extension that adds a table scan over the Arrow Flight
+/// connector used by the tests.
+class FlightPlanBuilder : public velox::exec::test::PlanBuilder {
+ public:
+  /// @brief Add a table scan node to the Plan, using the Flight connector
+  /// @param outputType The output type of the table scan node
+  /// @param assignments mapping from the column aliases to real column handles
+  /// @param createDefaultColumnHandles If true, generate column handles for
+  /// the columns which don't have an entry in assignments
+  /// @return Reference to the builder, so further nodes can be chained.
+  velox::exec::test::PlanBuilder& flightTableScan(
+      const velox::RowTypePtr& outputType,
+      std::unordered_map<
+          std::string,
+          std::shared_ptr> assignments = {},
+      bool createDefaultColumnHandles = true);
+};
+
+} // namespace facebook::presto::connector::arrow_flight::test
diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.cpp
new file mode 100644
index 0000000000000..672ea16730fc8
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.cpp
@@ -0,0 +1,34 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h"
+
+namespace facebook::presto::connector::arrow_flight::test {
+
+using namespace arrow::flight;
+
+// Serves the table whose registered name equals the ticket payload.
+// Returns KeyError when no table with that name is registered; otherwise
+// stores a stream over the table in `*stream` and returns OK.
+// `context` is not used.
+arrow::Status TestFlightServer::DoGet(
+    const ServerCallContext& context,
+    const Ticket& request,
+    std::unique_ptr* stream) {
+  auto it = tables_.find(request.ticket);
+  if (it == tables_.end()) {
+    return arrow::Status::KeyError("requested table does not exist");
+  }
+  auto& table = it->second;
+  // The reader is built over the stored table and then handed to the stream.
+  auto reader = std::make_shared(table);
+  *stream = std::make_unique(std::move(reader));
+  return arrow::Status::OK();
+}
+
+} // namespace facebook::presto::connector::arrow_flight::test
diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h
new file mode 100644
index 0000000000000..f9d924ad96c66
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h
@@ -0,0 +1,48 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "arrow/api.h"
+#include "arrow/flight/api.h"
+
+namespace facebook::presto::connector::arrow_flight::test {
+
+/// Test Flight server which supports DoGet operations.
+/// Maintains a list of named arrow tables, +/// +/// Normally, the tickets would be obtained by calling GetFlightInfo, +/// but since this is done by the coordinator this part is omitted. +/// Instead, the ticket is simply the name of the table to fetch. +class TestFlightServer : public arrow::flight::FlightServerBase { + public: + explicit TestFlightServer() {}; + + void updateTable(std::string name, std::shared_ptr table) { + tables_.emplace(std::move(name), std::move(table)); + } + + void removeTable(const std::string& name) { + tables_.erase(name); + } + + arrow::Status DoGet( + const arrow::flight::ServerCallContext& context, + const arrow::flight::Ticket& request, + std::unique_ptr* stream) override; + + private: + std::unordered_map> tables_; +}; + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp new file mode 100644 index 0000000000000..62c581a43be53 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "Utils.h" +#include +#include + +namespace facebook::presto::connector::arrow_flight::test { + +std::shared_ptr makeDecimalArray( + const std::vector& decimalValues, + int precision, + int scale) { + auto decimalType = arrow::decimal(precision, scale); + auto builder = + arrow::Decimal128Builder(decimalType, arrow::default_memory_pool()); + + for (const auto& value : decimalValues) { + arrow::Decimal128 dec(value); + AFC_RAISE_NOT_OK(builder.Append(dec)); + } + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeTimestampArray( + const std::vector& values, + arrow::TimeUnit::type timeUnit, + arrow::MemoryPool* memory_pool) { + arrow::TimestampBuilder builder(arrow::timestamp(timeUnit), memory_pool); + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeStringArray( + const std::vector& values) { + auto builder = arrow::StringBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeBooleanArray( + const std::vector& values) { + auto builder = arrow::BooleanBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +auto makeRecordBatch( + const std::vector& names, + const arrow::ArrayVector& arrays) { + VELOX_CHECK_EQ(names.size(), arrays.size()); + + auto nrows = (arrays.size() > 0) ? 
(arrays[0]->length()) : 0; + arrow::FieldVector fields{}; + for (int i = 0; i < arrays.size(); i++) { + VELOX_CHECK_EQ(arrays[i]->length(), nrows); + fields.push_back( + std::make_shared(names[i], arrays[i]->type())); + } + + auto schema = arrow::schema(fields); + return arrow::RecordBatch::Make(schema, nrows, arrays); +} + +std::shared_ptr makeArrowTable( + const std::vector& names, + const arrow::ArrayVector& arrays) { + AFC_RETURN_OR_RAISE( + arrow::Table::FromRecordBatches({makeRecordBatch(names, arrays)})); +} + +std::string readFile(const std::string& path) { + std::ifstream file(path); + VELOX_CHECK( + file.is_open(), "Could not open file \"{}\": {}", path, strerror(errno)); + return std::string( + (std::istreambuf_iterator(file)), + (std::istreambuf_iterator())); +} + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h new file mode 100644 index 0000000000000..b092d08b02170 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "arrow/api.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::presto::connector::arrow_flight::test { + +template +auto makeNumericArray(const std::vector& values) { + auto builder = arrow::NumericBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeDecimalArray( + const std::vector& decimalValues, + int precision, + int scale); + +std::shared_ptr makeTimestampArray( + const std::vector& values, + arrow::TimeUnit::type timeUnit, + arrow::MemoryPool* memory_pool = arrow::default_memory_pool()); + +std::shared_ptr makeStringArray( + const std::vector& values); + +std::shared_ptr makeBooleanArray(const std::vector& values); + +auto makeRecordBatch( + const std::vector& names, + const arrow::ArrayVector& arrays); + +std::shared_ptr makeArrowTable( + const std::vector& names, + const arrow::ArrayVector& arrays); + +std::string readFile(const std::string& path); + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp index 29017b930a0b8..5899215bc3dc7 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp @@ -1097,7 +1097,8 @@ velox::connector::hive::iceberg::FileContent toVeloxFileContent( std::unique_ptr HivePrestoToVeloxConnector::toVeloxSplit( const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* const connectorSplit) const { + const protocol::ConnectorSplit* const connectorSplit, + const std::map& extraCredentials) const { auto hiveSplit = dynamic_cast(connectorSplit); VELOX_CHECK_NOT_NULL( @@ -1331,7 +1332,8 @@ HivePrestoToVeloxConnector::createConnectorProtocol() const { std::unique_ptr 
IcebergPrestoToVeloxConnector::toVeloxSplit( const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* const connectorSplit) const { + const protocol::ConnectorSplit* const connectorSplit, + const std::map& extraCredentials) const { auto icebergSplit = dynamic_cast(connectorSplit); VELOX_CHECK_NOT_NULL( @@ -1482,7 +1484,8 @@ IcebergPrestoToVeloxConnector::createConnectorProtocol() const { std::unique_ptr TpchPrestoToVeloxConnector::toVeloxSplit( const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* const connectorSplit) const { + const protocol::ConnectorSplit* const connectorSplit, + const std::map& extraCredentials) const { auto tpchSplit = dynamic_cast(connectorSplit); VELOX_CHECK_NOT_NULL( diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h index 6d80751778e6a..23a3387f97279 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h @@ -45,7 +45,9 @@ class PrestoToVeloxConnector { [[nodiscard]] virtual std::unique_ptr toVeloxSplit( const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* connectorSplit) const = 0; + const protocol::ConnectorSplit* connectorSplit, + const std::map& extraCredentials = {}) + const = 0; [[nodiscard]] virtual std::unique_ptr toVeloxColumnHandle( @@ -115,7 +117,9 @@ class HivePrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr toVeloxSplit( const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* connectorSplit) const final; + const protocol::ConnectorSplit* connectorSplit, + const std::map& extraCredentials = {}) + const final; std::unique_ptr toVeloxColumnHandle( const protocol::ColumnHandle* column, @@ -166,7 +170,9 @@ class IcebergPrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr toVeloxSplit( const 
protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* connectorSplit) const final; + const protocol::ConnectorSplit* connectorSplit, + const std::map& extraCredentials = {}) + const final; std::unique_ptr toVeloxColumnHandle( const protocol::ColumnHandle* column, @@ -192,7 +198,9 @@ class TpchPrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr toVeloxSplit( const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* connectorSplit) const final; + const protocol::ConnectorSplit* connectorSplit, + const std::map& extraCredentials = {}) + const final; std::unique_ptr toVeloxColumnHandle( const protocol::ColumnHandle* column, diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp index 47522c7842098..62d6f0e199efe 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp @@ -20,7 +20,8 @@ using namespace facebook::velox; namespace facebook::presto { velox::exec::Split toVeloxSplit( - const presto::protocol::ScheduledSplit& scheduledSplit) { + const presto::protocol::ScheduledSplit& scheduledSplit, + const std::map& extraCredentials) { const auto& connectorSplit = scheduledSplit.split.connectorSplit; const auto splitGroupId = scheduledSplit.split.lifespan.isgroup ? 
scheduledSplit.split.lifespan.groupid @@ -40,7 +41,7 @@ velox::exec::Split toVeloxSplit( auto& connector = getPrestoToVeloxConnector(connectorSplit->_type); auto veloxSplit = connector.toVeloxSplit( - scheduledSplit.split.connectorId, connectorSplit.get()); + scheduledSplit.split.connectorId, connectorSplit.get(), extraCredentials); return velox::exec::Split(std::move(veloxSplit), splitGroupId); } diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.h index ff8692c2700a0..5b93a767ce029 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.h @@ -21,6 +21,7 @@ namespace facebook::presto { // Creates and returns exec::Split (with connector::ConnectorSplit inside) based // on the given protocol split. velox::exec::Split toVeloxSplit( - const presto::protocol::ScheduledSplit& scheduledSplit); + const presto::protocol::ScheduledSplit& scheduledSplit, + const std::map& extraCredentials = {}); } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/presto_protocol/Makefile b/presto-native-execution/presto_cpp/presto_protocol/Makefile index 3ee2b4e802b81..09b43df28b4f5 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/Makefile +++ b/presto-native-execution/presto_cpp/presto_protocol/Makefile @@ -45,14 +45,23 @@ presto_protocol-cpp: presto_protocol-json chevron -d connector/tpch/presto_protocol_tpch.json connector/tpch/presto_protocol-json-hpp.mustache >> connector/tpch/presto_protocol_tpch.h clang-format -style=file -i connector/tpch/presto_protocol_tpch.h connector/tpch/presto_protocol_tpch.cpp + # build arrow_flight connector related structs + echo "// DO NOT EDIT : This file is generated by chevron" > connector/arrow_flight/presto_protocol_arrow_flight.cpp + chevron -d connector/arrow_flight/presto_protocol_arrow_flight.json 
connector/arrow_flight/presto_protocol-json-cpp.mustache >> connector/arrow_flight/presto_protocol_arrow_flight.cpp + echo "// DO NOT EDIT : This file is generated by chevron" > connector/arrow_flight/presto_protocol_arrow_flight.h + chevron -d connector/arrow_flight/presto_protocol_arrow_flight.json connector/arrow_flight/presto_protocol-json-hpp.mustache >> connector/arrow_flight/presto_protocol_arrow_flight.h + clang-format -style=file -i connector/arrow_flight/presto_protocol_arrow_flight.h connector/arrow_flight/presto_protocol_arrow_flight.cpp + presto_protocol-json: ./java-to-struct-json.py --config core/presto_protocol_core.yml core/special/*.java core/special/*.inc -j | jq . > core/presto_protocol_core.json ./java-to-struct-json.py --config connector/hive/presto_protocol_hive.yml connector/hive/special/*.inc -j | jq . > connector/hive/presto_protocol_hive.json ./java-to-struct-json.py --config connector/iceberg/presto_protocol_iceberg.yml connector/iceberg/special/*.inc -j | jq . > connector/iceberg/presto_protocol_iceberg.json ./java-to-struct-json.py --config connector/tpch/presto_protocol_tpch.yml connector/tpch/special/*.inc -j | jq . > connector/tpch/presto_protocol_tpch.json + ./java-to-struct-json.py --config connector/arrow_flight/presto_protocol_arrow_flight.yml connector/arrow_flight/special/*.inc -j | jq . 
> connector/arrow_flight/presto_protocol_arrow_flight.json presto_protocol.proto: presto_protocol-json pystache presto_protocol-protobuf.mustache core/presto_protocol_core.json > core/presto_protocol_core.proto pystache presto_protocol-protobuf.mustache connector/hive/presto_protocol_hive.json > connector/hive/presto_protocol_hive.proto pystache presto_protocol-protobuf.mustache connector/iceberg/presto_protocol_iceberg.json > connector/iceberg/presto_protocol_iceberg.proto pystache presto_protocol-protobuf.mustache connector/tpch/presto_protocol_tpch.json > connector/tpch/presto_protocol_tpch.proto + pystache presto_protocol-protobuf.mustache connector/arrow_flight/presto_protocol_arrow_flight.json > connector/arrow_flight/presto_protocol_arrow_flight.proto diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h new file mode 100644 index 0000000000000..95cda16115695 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" +#include "presto_cpp/presto_protocol/core/ConnectorProtocol.h" + +namespace facebook::presto::protocol::arrow_flight { +using ArrowConnectorProtocol = ConnectorProtocolTemplate< + ArrowTableHandle, + ArrowTableLayoutHandle, + ArrowColumnHandle, + NotImplemented, + NotImplemented, + ArrowSplit, + NotImplemented, + ArrowTransactionHandle, + NotImplemented>; +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache new file mode 100644 index 0000000000000..b6ecb68507285 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache @@ -0,0 +1,150 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// presto_protocol.prolog.cpp +// + +{{#.}} +{{#comment}} +{{comment}} +{{/comment}} +{{/.}} + +#include + +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" +using namespace std::string_literals; + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. 
+ +namespace facebook::presto::protocol::arrow_flight { + +void to_json(json& j, const ArrowTransactionHandle& p) { + j = json::array(); + j.push_back(p._type); + j.push_back(p.instance); +} + +void from_json(const json& j, ArrowTransactionHandle& p) { + j[0].get_to(p._type); + j[1].get_to(p.instance); +} +} // namespace facebook::presto::protocol::arrow_flight +{{#.}} +{{#cinc}} +{{&cinc}} +{{/cinc}} +{{^cinc}} +{{#struct}} +namespace facebook::presto::protocol::arrow_flight { + {{#super_class}} + {{&class_name}}::{{&class_name}}() noexcept { + _type = "{{json_key}}"; + } + {{/super_class}} + + void to_json(json& j, const {{&class_name}}& p) { + j = json::object(); + {{#super_class}} + j["@type"] = "{{&json_key}}"; + {{/super_class}} + {{#fields}} + to_json_key(j, "{{&field_name}}", p.{{field_name}}, "{{&class_name}}", "{{&field_text}}", "{{&field_name}}"); + {{/fields}} + } + + void from_json(const json& j, {{&class_name}}& p) { + {{#super_class}} + p._type = j["@type"]; + {{/super_class}} + {{#fields}} + from_json_key(j, "{{&field_name}}", p.{{field_name}}, "{{&class_name}}", "{{&field_text}}", "{{&field_name}}"); + {{/fields}} + } +} +{{/struct}} +{{#enum}} +namespace facebook::presto::protocol::arrow_flight { + // Loosely copied here from NLOHMANN_JSON_SERIALIZE_ENUM() + + // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays + static const std::pair<{{&class_name}}, json> + {{&class_name}}_enum_table[] = { // NOLINT: cert-err58-cpp + {{#elements}} + { {{&class_name}}::{{&element}}, "{{&element}}" }{{^_last}},{{/_last}} + {{/elements}} + }; + void to_json(json& j, const {{&class_name}}& e) + { + static_assert(std::is_enum<{{&class_name}}>::value, "{{&class_name}} must be an enum!"); + const auto* it = std::find_if(std::begin({{&class_name}}_enum_table), std::end({{&class_name}}_enum_table), + [e](const std::pair<{{&class_name}}, json>& ej_pair) -> bool + { + return ej_pair.first == e; + }); + j = ((it != std::end({{&class_name}}_enum_table)) ?
it : std::begin({{&class_name}}_enum_table))->second; + } + void from_json(const json& j, {{&class_name}}& e) + { + static_assert(std::is_enum<{{&class_name}}>::value, "{{&class_name}} must be an enum!"); + const auto* it = std::find_if(std::begin({{&class_name}}_enum_table), std::end({{&class_name}}_enum_table), + [&j](const std::pair<{{&class_name}}, json>& ej_pair) -> bool + { + return ej_pair.second == j; + }); + e = ((it != std::end({{&class_name}}_enum_table)) ? it : std::begin({{&class_name}}_enum_table))->first; + } +} +{{/enum}} +{{#abstract}} +namespace facebook::presto::protocol::arrow_flight { + void to_json(json& j, const std::shared_ptr<{{&class_name}}>& p) { + if ( p == nullptr ) { + return; + } + String type = p->_type; + + {{#subclasses}} + if ( type == "{{&key}}" ) { + j = *std::static_pointer_cast<{{&type}}>(p); + return; + } + {{/subclasses}} + + throw TypeError(type + " no abstract type {{&class_name}} {{&key}}"); + } + + void from_json(const json& j, std::shared_ptr<{{&class_name}}>& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error &e) { + throw ParseError(std::string(e.what()) + " {{&class_name}} {{&key}} {{&class_name}}"); + } + + {{#subclasses}} + if ( type == "{{&key}}" ) { + std::shared_ptr<{{&type}}> k = std::make_shared<{{&type}}>(); + j.get_to(*k); + p = std::static_pointer_cast<{{&class_name}}>(k); + return; + } + {{/subclasses}} + + throw TypeError(type + " no abstract type {{&class_name}} {{&key}}"); + } +} +{{/abstract}} +{{/cinc}} +{{/.}} diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache new file mode 100644 index 0000000000000..be08bd9e491c2 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache @@ -0,0 +1,76 @@ +/* + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +{{#.}} +{{#comment}} +{{comment}} +{{/comment}} +{{/.}} + +#include +#include +#include +#include + +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +struct ArrowTransactionHandle : public ConnectorTransactionHandle { + String instance = {}; +}; + +void to_json(json& j, const ArrowTransactionHandle& p); + +void from_json(const json& j, ArrowTransactionHandle& p); + +} // namespace facebook::presto::protocol::arrow_flight +{{#.}} +{{#hinc}} +{{&hinc}} +{{/hinc}} +{{^hinc}} +{{#struct}} +namespace facebook::presto::protocol::arrow_flight { + struct {{class_name}} {{#super_class}}: public {{super_class}}{{/super_class}}{ + {{#fields}} + {{#field_local}}{{#optional}}std::shared_ptr<{{/optional}}{{&field_text}}{{#optional}}>{{/optional}} {{&field_name}} = {};{{/field_local}} + {{/fields}} + + {{#super_class}} + {{class_name}}() noexcept; + {{/super_class}} + }; + void to_json(json& j, const {{class_name}}& p); + void from_json(const json& j, {{class_name}}& p); +} +{{/struct}} +{{#enum}} +namespace facebook::presto::protocol::arrow_flight { + enum class {{class_name}} { + {{#elements}} + {{&element}}{{^_last}},{{/_last}} + {{/elements}} + }; + extern void to_json(json& j, const {{class_name}}& e); + 
extern void from_json(const json& j, {{class_name}}& e); +} +{{/enum}} +{{/hinc}} +{{/.}} diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp new file mode 100644 index 0000000000000..93e3318911e86 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp @@ -0,0 +1,231 @@ +// DO NOT EDIT : This file is generated by chevron +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// presto_protocol.prolog.cpp +// + +// This file is generated DO NOT EDIT @generated + +#include + +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" +using namespace std::string_literals; + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +void to_json(json& j, const ArrowTransactionHandle& p) { + j = json::array(); + j.push_back(p._type); + j.push_back(p.instance); +} + +void from_json(const json& j, ArrowTransactionHandle& p) { + j[0].get_to(p._type); + j[1].get_to(p.instance); +} +} // namespace facebook::presto::protocol::arrow_flight +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace facebook::presto::protocol::arrow_flight { +ArrowColumnHandle::ArrowColumnHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowColumnHandle& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key( + j, + "columnName", + p.columnName, + "ArrowColumnHandle", + "String", + "columnName"); + to_json_key( + j, "columnType", p.columnType, "ArrowColumnHandle", "Type", "columnType"); +} + +void from_json(const json& j, ArrowColumnHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "columnName", + p.columnName, + "ArrowColumnHandle", + "String", + "columnName"); + from_json_key( + j, "columnType", p.columnType, "ArrowColumnHandle", "Type", "columnType"); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowSplit::ArrowSplit() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowSplit& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key( + j, "schemaName", p.schemaName, "ArrowSplit", "String", "schemaName"); + to_json_key(j, "tableName", p.tableName, "ArrowSplit", "String", "tableName"); + to_json_key(j, "ticket", p.ticket, "ArrowSplit", "String", "ticket"); + to_json_key( + j, + "locationUrls", + p.locationUrls, + "ArrowSplit", + "List", + "locationUrls"); +} + +void from_json(const json& j, ArrowSplit& p) { + p._type = j["@type"]; + from_json_key( + j, "schemaName", p.schemaName, "ArrowSplit", "String", "schemaName"); + from_json_key( + j, "tableName", p.tableName, 
"ArrowSplit", "String", "tableName"); + from_json_key(j, "ticket", p.ticket, "ArrowSplit", "String", "ticket"); + from_json_key( + j, + "locationUrls", + p.locationUrls, + "ArrowSplit", + "List", + "locationUrls"); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowTableHandle::ArrowTableHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowTableHandle& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key(j, "schema", p.schema, "ArrowTableHandle", "String", "schema"); + to_json_key(j, "table", p.table, "ArrowTableHandle", "String", "table"); +} + +void from_json(const json& j, ArrowTableHandle& p) { + p._type = j["@type"]; + from_json_key(j, "schema", p.schema, "ArrowTableHandle", "String", "schema"); + from_json_key(j, "table", p.table, "ArrowTableHandle", "String", "table"); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +void to_json(json& j, const std::shared_ptr& p) { + if (p == nullptr) { + return; + } + String type = p->_type; + + if (type == "arrow-flight") { + j = *std::static_pointer_cast(p); + return; + } + + throw TypeError(type + " no abstract type ColumnHandle "); +} + +void from_json(const json& j, std::shared_ptr& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error& e) { + throw ParseError(std::string(e.what()) + " ColumnHandle ColumnHandle"); + } + + if (type == "arrow-flight") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + + throw TypeError(type + " no abstract type ColumnHandle "); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowTableLayoutHandle::ArrowTableLayoutHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowTableLayoutHandle& p) { + j = json::object(); + 
j["@type"] = "arrow-flight"; + to_json_key( + j, + "table", + p.table, + "ArrowTableLayoutHandle", + "ArrowTableHandle", + "table"); + to_json_key( + j, + "columnHandles", + p.columnHandles, + "ArrowTableLayoutHandle", + "List", + "columnHandles"); + to_json_key( + j, + "tupleDomain", + p.tupleDomain, + "ArrowTableLayoutHandle", + "TupleDomain>", + "tupleDomain"); +} + +void from_json(const json& j, ArrowTableLayoutHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "table", + p.table, + "ArrowTableLayoutHandle", + "ArrowTableHandle", + "table"); + from_json_key( + j, + "columnHandles", + p.columnHandles, + "ArrowTableLayoutHandle", + "List", + "columnHandles"); + from_json_key( + j, + "tupleDomain", + p.tupleDomain, + "ArrowTableLayoutHandle", + "TupleDomain>", + "tupleDomain"); +} +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h new file mode 100644 index 0000000000000..5fed8f087cd2f --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h @@ -0,0 +1,104 @@ +// DO NOT EDIT : This file is generated by chevron +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +// This file is generated DO NOT EDIT @generated + +#include +#include +#include +#include + +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +struct ArrowTransactionHandle : public ConnectorTransactionHandle { + String instance = {}; +}; + +void to_json(json& j, const ArrowTransactionHandle& p); + +void from_json(const json& j, ArrowTransactionHandle& p); + +} // namespace facebook::presto::protocol::arrow_flight +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// ArrowColumnHandle is special since it does not require all the +// properties from the corresponding java class + +namespace facebook::presto::protocol::arrow_flight { +struct ArrowColumnHandle : public ColumnHandle { + String columnName = {}; + Type columnType = {}; + + ArrowColumnHandle() noexcept; + + bool operator<(const ColumnHandle& o) const override { + return columnName < dynamic_cast(o).columnName; + } +}; +void to_json(json& j, const ArrowColumnHandle& p); +void from_json(const json& j, ArrowColumnHandle& p); +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowSplit : public ConnectorSplit { + String schemaName = {}; + String tableName = {}; + String ticket = {}; + List locationUrls = {}; + + ArrowSplit() noexcept; +}; +void to_json(json& j, const ArrowSplit& p); +void from_json(const json& j, ArrowSplit& p); +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowTableHandle : public ConnectorTableHandle { + String schema = {}; + String table = {}; + + ArrowTableHandle() noexcept; +}; +void to_json(json& j, const ArrowTableHandle& p); +void from_json(const json& j, ArrowTableHandle& p); +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowTableLayoutHandle : public ConnectorTableLayoutHandle { + ArrowTableHandle table = {}; + List columnHandles = {}; + TupleDomain> tupleDomain = {}; + + ArrowTableLayoutHandle() noexcept; +}; +void to_json(json& j, const ArrowTableLayoutHandle& p); +void from_json(const json& j, ArrowTableLayoutHandle& p); +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml new file mode 
100644 index 0000000000000..f34f6068eb777 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml @@ -0,0 +1,40 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +AbstractClasses: + ColumnHandle: + super: JsonEncodedSubclass + comparable: true + subclasses: + - { name: ArrowColumnHandle, key: arrow-flight } + + ConnectorTableHandle: + super: JsonEncodedSubclass + subclasses: + - { name: ArrowTableHandle, key: arrow-flight } + + ConnectorTableLayoutHandle: + super: JsonEncodedSubclass + subclasses: + - { name: ArrowTableLayoutHandle, key: arrow-flight } + + ConnectorSplit: + super: JsonEncodedSubclass + subclasses: + - { name: ArrowSplit, key: arrow-flight } + +JavaClasses: + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowColumnHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowColumnHandle.cpp.inc new file mode 100644 index 0000000000000..d6a2c7f5fcc18 --- /dev/null +++ 
b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowColumnHandle.cpp.inc @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace facebook::presto::protocol::arrow_flight { +ArrowColumnHandle::ArrowColumnHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowColumnHandle& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key( + j, + "columnName", + p.columnName, + "ArrowColumnHandle", + "String", + "columnName"); + to_json_key( + j, "columnType", p.columnType, "ArrowColumnHandle", "Type", "columnType"); +} + +void from_json(const json& j, ArrowColumnHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "columnName", + p.columnName, + "ArrowColumnHandle", + "String", + "columnName"); + from_json_key( + j, "columnType", p.columnType, "ArrowColumnHandle", "Type", "columnType"); +} +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowColumnHandle.hpp.inc b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowColumnHandle.hpp.inc new file mode 100644 index 0000000000000..b4c3872dddd6f --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowColumnHandle.hpp.inc @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ArrowColumnHandle is special since it does not require all the +// properties from the corresponding java class + +namespace facebook::presto::protocol::arrow_flight { +struct ArrowColumnHandle : public ColumnHandle { + String columnName = {}; + Type columnType = {}; + + ArrowColumnHandle() noexcept; + + bool operator<(const ColumnHandle& o) const override { + return columnName < dynamic_cast<const ArrowColumnHandle&>(o).columnName; + } +}; +void to_json(json& j, const ArrowColumnHandle& p); +void from_json(const json& j, ArrowColumnHandle& p); +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc new file mode 100644 index 0000000000000..a93325f5b154a --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +void to_json(json& j, const ArrowTransactionHandle& p) { + j = json::array(); + j.push_back(p._type); + j.push_back(p.instance); +} + +void from_json(const json& j, ArrowTransactionHandle& p) { + j.at(0).get_to(p._type); + j.at(1).get_to(p.instance); +} +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc new file mode 100644 index 0000000000000..dc573ca2e68cf --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. 
+ +namespace facebook::presto::protocol::arrow_flight { + +struct ArrowTransactionHandle : public ConnectorTransactionHandle { + String instance = {}; +}; + +void to_json(json& j, const ArrowTransactionHandle& p); + +void from_json(const json& j, ArrowTransactionHandle& p); + +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index 6bfc625743dba..2ef515165d099 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -1069,6 +1069,7 @@ void from_json(const json& j, std::shared_ptr& p) { */ // dependency TpchTransactionHandle +// dependency ArrowTransactionHandle namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index cbce83539ca17..a42454d7359fb 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -67,21 +67,21 @@ extern const char* const PRESTO_ABORT_TASK_URL_PARAM; class Exception : public std::runtime_error { public: explicit Exception(const std::string& message) - : std::runtime_error(message){}; + : std::runtime_error(message) {}; }; class TypeError : public Exception { public: - explicit TypeError(const std::string& message) : Exception(message){}; + explicit TypeError(const std::string& message) : Exception(message) {}; }; class OutOfRange : public Exception { public: - explicit OutOfRange(const std::string& message) : Exception(message){}; + explicit OutOfRange(const std::string& message) : Exception(message) {}; }; class 
ParseError : public Exception { public: - explicit ParseError(const std::string& message) : Exception(message){}; + explicit ParseError(const std::string& message) : Exception(message) {}; }; using String = std::string; diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml index 51669ad6891e0..f9d3bf85049bb 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml @@ -54,6 +54,7 @@ AbstractClasses: - { name: IcebergColumnHandle, key: hive-iceberg } - { name: TpchColumnHandle, key: tpch } - { name: SystemColumnHandle, key: $system@system } + - { name: ArrowColumnHandle, key: arrow-flight } ConnectorPartitioningHandle: super: JsonEncodedSubclass @@ -69,6 +70,7 @@ AbstractClasses: - { name: IcebergTableHandle, key: hive-iceberg } - { name: TpchTableHandle, key: tpch } - { name: SystemTableHandle, key: $system@system } + - { name: ArrowTableHandle, key: arrow-flight } ConnectorOutputTableHandle: super: JsonEncodedSubclass @@ -96,6 +98,7 @@ AbstractClasses: - { name: IcebergTableLayoutHandle, key: hive-iceberg } - { name: TpchTableLayoutHandle, key: tpch } - { name: SystemTableLayoutHandle, key: $system@system } + - { name: ArrowTableLayoutHandle, key: arrow-flight } ConnectorMetadataUpdateHandle: super: JsonEncodedSubclass @@ -111,6 +114,7 @@ AbstractClasses: - { name: RemoteSplit, key: $remote } - { name: EmptySplit, key: $empty } - { name: SystemSplit, key: $system@system } + - { name: ArrowSplit, key: arrow-flight } ConnectorHistogram: super: JsonEncodedSubclass diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc index 8ec2a94e84bd9..1dfb17e4a908f 100644 --- 
a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc @@ -13,6 +13,7 @@ */ // dependency TpchTransactionHandle +// dependency ArrowTransactionHandle namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp index c15084817a434..24f24f27f87a3 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp @@ -15,6 +15,7 @@ // DEPRECATED: This file is deprecated and will be removed in future versions. +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp" #include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp" #include "presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp" #include "presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.cpp" diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h index dd94975e3760d..c43ec92629f44 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h @@ -16,6 +16,7 @@ // DEPRECATED: This file is deprecated and will be removed in future versions. 
+#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" #include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.h" #include "presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h" #include "presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.h" diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml index 18c9afda02ece..83a18b28a72ad 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml @@ -53,6 +53,7 @@ AbstractClasses: - { name: IcebergColumnHandle, key: hive-iceberg } - { name: TpchColumnHandle, key: tpch } - { name: SystemColumnHandle, key: $system@system } + - { name: ArrowColumnHandle, key: arrow-flight } ConnectorPartitioningHandle: super: JsonEncodedSubclass @@ -68,6 +69,7 @@ AbstractClasses: - { name: IcebergTableHandle, key: hive-iceberg } - { name: TpchTableHandle, key: tpch } - { name: SystemTableHandle, key: $system@system } + - { name: ArrowTableHandle, key: arrow-flight } ConnectorOutputTableHandle: super: JsonEncodedSubclass @@ -95,6 +97,7 @@ AbstractClasses: - { name: IcebergTableLayoutHandle, key: hive-iceberg } - { name: TpchTableLayoutHandle, key: tpch } - { name: SystemTableLayoutHandle, key: $system@system } + - { name: ArrowTableLayoutHandle, key: arrow-flight } ConnectorMetadataUpdateHandle: super: JsonEncodedSubclass @@ -110,6 +113,7 @@ AbstractClasses: - { name: RemoteSplit, key: $remote } - { name: EmptySplit, key: $empty } - { name: SystemSplit, key: $system@system } + - { name: ArrowSplit, key: arrow-flight } ConnectorHistogram: super: JsonEncodedSubclass @@ -365,3 +369,7 @@ JavaClasses: - presto-main/src/main/java/com/facebook/presto/connector/system/SystemTransactionHandle.java - 
presto-spi/src/main/java/com/facebook/presto/spi/function/AggregationFunctionMetadata.java - presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonBasedUdfFunctionMetadata.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java diff --git a/presto-native-execution/scripts/setup-adapters.sh b/presto-native-execution/scripts/setup-adapters.sh index 6c36424ebf90c..532ec01d9e8f8 100755 --- a/presto-native-execution/scripts/setup-adapters.sh +++ b/presto-native-execution/scripts/setup-adapters.sh @@ -35,15 +35,73 @@ function install_prometheus_cpp { cmake_install -DBUILD_SHARED_LIBS=ON -DENABLE_PUSH=OFF -DENABLE_COMPRESSION=OFF } +function install_abseil { + # abseil-cpp + github_checkout abseil/abseil-cpp 20240116.2 --depth 1 + cmake_install \ + -DABSL_BUILD_TESTING=OFF \ + -DCMAKE_CXX_STANDARD=17 \ + -DABSL_PROPAGATE_CXX_STD=ON \ + -DABSL_ENABLE_INSTALL=ON +} + +function install_grpc { + # grpc + github_checkout grpc/grpc v1.48.1 --depth 1 + cmake_install \ + -DgRPC_BUILD_TESTS=OFF \ + -DgRPC_ABSL_PROVIDER=package \ + -DgRPC_ZLIB_PROVIDER=package \ + -DgRPC_CARES_PROVIDER=package \ + -DgRPC_RE2_PROVIDER=package \ + -DgRPC_SSL_PROVIDER=package \ + -DgRPC_PROTOBUF_PROVIDER=package \ + -DgRPC_INSTALL=ON +} + +function install_arrow_flight { + ARROW_VERSION="15.0.0" + if [[ "$OSTYPE" == "linux-gnu"* ]]; then + export INSTALL_PREFIX=${INSTALL_PREFIX:-"/usr/local"} + LINUX_DISTRIBUTION=$(. 
/etc/os-release && echo ${ID}) + if [[ "$LINUX_DISTRIBUTION" == "ubuntu" || "$LINUX_DISTRIBUTION" == "debian" ]]; then + SUDO="${SUDO:-"sudo --preserve-env"}" + ${SUDO} apt install -y libc-ares-dev + ${SUDO} ldconfig -v 2>/dev/null | grep "${INSTALL_PREFIX}/lib" || \ + echo "${INSTALL_PREFIX}/lib" | ${SUDO} tee /etc/ld.so.conf.d/local-libraries.conf > /dev/null \ + && ${SUDO} ldconfig + else + dnf -y install c-ares-devel + ldconfig -v 2>/dev/null | grep "${INSTALL_PREFIX}/lib" || \ + echo "${INSTALL_PREFIX}/lib" | tee /etc/ld.so.conf.d/local-libraries.conf > /dev/null \ + && ldconfig + fi + else + # The installation script for the Arrow Flight connector currently works only on Linux distributions. + return 0 + fi + + install_abseil + install_grpc + + wget_and_untar https://github.com/apache/arrow/archive/apache-arrow-${ARROW_VERSION}.tar.gz arrow + cmake_install_dir arrow/cpp \ + -DARROW_FLIGHT=ON \ + -DARROW_BUILD_BENCHMARKS=ON \ + -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} +} + cd "${DEPENDENCY_DIR}" || exit install_jwt=0 install_prometheus_cpp=0 +install_arrow_flight=0 if [ "$#" -eq 0 ]; then # Install all adapters by default install_jwt=1 install_prometheus_cpp=1 + install_arrow_flight=1 fi while [[ $# -gt 0 ]]; do @@ -56,6 +114,10 @@ while [[ $# -gt 0 ]]; do install_prometheus_cpp=1; shift ;; + arrow_flight) + install_arrow_flight=1; + shift + ;; *) echo "ERROR: Unknown option $1! will be ignored!" shift @@ -72,6 +134,10 @@ if [ $install_prometheus_cpp -eq 1 ]; then install_prometheus_cpp fi +if [ $install_arrow_flight -eq 1 ]; then + install_arrow_flight +fi + _ret=$? if [ $_ret -eq 0 ] ; then echo "All deps for Presto adapters installed!"