diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 375e049d5f78c..80b3531ceb05b 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -1264,14 +1264,12 @@ std::vector PrestoServer::registerVeloxConnectors( // make sure connector type is supported getPrestoToVeloxConnector(connectorName); - - std::shared_ptr connector = - velox::connector::getConnectorFactory(connectorName) - ->newConnector( - catalogName, - std::move(properties), - connectorIoExecutor_.get(), - connectorCpuExecutor_.get()); + auto connector = getConnectorFactory(connectorName) + ->newConnector( + catalogName, + std::move(properties), + connectorIoExecutor_.get(), + connectorCpuExecutor_.get()); velox::connector::registerConnector(connector); } } diff --git a/presto-native-execution/presto_cpp/main/connectors/Registration.cpp b/presto-native-execution/presto_cpp/main/connectors/Registration.cpp index d6f6555fb8a22..098ccea00db54 100644 --- a/presto-native-execution/presto_cpp/main/connectors/Registration.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/Registration.cpp @@ -28,44 +28,45 @@ namespace { constexpr char const* kHiveHadoop2ConnectorName = "hive-hadoop2"; constexpr char const* kIcebergConnectorName = "iceberg"; -void registerConnectorFactories() { - // These checks for connector factories can be removed after we remove the - // registrations from the Velox library. - if (!velox::connector::hasConnectorFactory( - velox::connector::hive::HiveConnectorFactory::kHiveConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); - velox::connector::registerConnectorFactory( - std::make_shared( - kHiveHadoop2ConnectorName)); - } - if (!velox::connector::hasConnectorFactory( - velox::connector::tpch::TpchConnectorFactory::kTpchConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); - } - - // Register Velox connector factory for iceberg. - // The iceberg catalog is handled by the hive connector factory. - if (!velox::connector::hasConnectorFactory(kIcebergConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared( - kIcebergConnectorName)); - } - +const std::unordered_map< + std::string, + const std::shared_ptr>& +connectorFactories() { + static const std::unordered_map< + std::string, + const std::shared_ptr> + factories = { + {velox::connector::hive::HiveConnectorFactory::kHiveConnectorName, + std::make_shared()}, + {kHiveHadoop2ConnectorName, + std::make_shared( + kHiveHadoop2ConnectorName)}, + {velox::connector::tpch::TpchConnectorFactory::kTpchConnectorName, + std::make_shared()}, + {kIcebergConnectorName, + std::make_shared( + kIcebergConnectorName)}, #ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR - if (!velox::connector::hasConnectorFactory( - ArrowFlightConnectorFactory::kArrowFlightConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); - } + {ArrowFlightConnectorFactory::kArrowFlightConnectorName, + std::make_shared()}, #endif + }; + return factories; } + } // namespace -void registerConnectors() { - registerConnectorFactories(); +velox::connector::ConnectorFactory* getConnectorFactory( + const std::string& connectorName) { + auto it = connectorFactories().find(connectorName); + VELOX_CHECK( + it != connectorFactories().end(), + "ConnectorFactory with name '{}' not registered", + connectorName); + return it->second.get(); +} +void registerConnectors() { registerPrestoToVeloxConnector(std::make_unique( velox::connector::hive::HiveConnectorFactory::kHiveConnectorName)); registerPrestoToVeloxConnector( diff --git a/presto-native-execution/presto_cpp/main/connectors/Registration.h b/presto-native-execution/presto_cpp/main/connectors/Registration.h index c95aefaacfcaa..ee46dce009b9f 100644 --- a/presto-native-execution/presto_cpp/main/connectors/Registration.h +++ b/presto-native-execution/presto_cpp/main/connectors/Registration.h @@ -13,8 +13,18 @@ */ #pragma once +#include + +// Forward declaration for ConnectorFactory. +namespace facebook::velox::connector { +class ConnectorFactory; +} // namespace facebook::velox::connector + namespace facebook::presto { +velox::connector::ConnectorFactory* getConnectorFactory( + const std::string& connectorName); + void registerConnectors(); } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp index 1fecf9e31977a..2b692880f073d 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp @@ -25,16 +25,9 @@ namespace facebook::presto::test { void ArrowFlightConnectorTestBase::SetUp() { OperatorTestBase::SetUp(); - - if (!velox::connector::hasConnectorFactory( - presto::ArrowFlightConnectorFactory::kArrowFlightConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); - } + presto::ArrowFlightConnectorFactory factory; velox::connector::registerConnector( - velox::connector::getConnectorFactory( - ArrowFlightConnectorFactory::kArrowFlightConnectorName) - ->newConnector(kFlightConnectorId, config_)); + factory.newConnector(kFlightConnectorId, config_)); ArrowFlightConfig config(config_); if (config.defaultServerPort().has_value()) { diff --git a/presto-native-execution/presto_cpp/main/tests/ServerOperationTest.cpp b/presto-native-execution/presto_cpp/main/tests/ServerOperationTest.cpp index 68c820486d953..02c47b8cff7a7 100644 --- a/presto-native-execution/presto_cpp/main/tests/ServerOperationTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/ServerOperationTest.cpp @@ -140,18 +140,11 @@ TEST_F(ServerOperationTest, buildServerOp) { TEST_F(ServerOperationTest, taskEndpoint) { // Setup environment for TaskManager - if (!connector::hasConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName)) { - connector::registerConnectorFactory( - std::make_shared()); - } - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - "test-hive", - std::make_shared( - std::unordered_map())); + connector::hive::HiveConnectorFactory factory; + auto hiveConnector = factory.newConnector( + "test-hive", + std::make_shared( + std::unordered_map())); connector::registerConnector(hiveConnector); const auto driverExecutor = std::make_shared( diff --git a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp index 02f84d83a77ed..1b8c33ff04d92 100644 --- a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp @@ -195,11 +195,6 @@ class TaskManagerTest : public exec::test::OperatorTestBase, static void SetUpTestCase() { OperatorTestBase::SetUpTestCase(); filesystems::registerLocalFileSystem(); - if (!connector::hasConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName)) { - connector::registerConnectorFactory( - std::make_shared()); - } test::setupMutableSystemConfig(); SystemConfig::instance()->setValue( std::string(SystemConfig::kMemoryArbitratorKind), "SHARED"); @@ -233,13 +228,11 @@ class TaskManagerTest : public exec::test::OperatorTestBase, registerPrestoToVeloxConnector(std::make_unique( connector::hive::HiveConnectorFactory::kHiveConnectorName)); - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - kHiveConnectorId, - std::make_shared( - std::unordered_map())); + connector::hive::HiveConnectorFactory factory; + auto hiveConnector = factory.newConnector( + kHiveConnectorId, + std::make_shared( + std::unordered_map())); connector::registerConnector(hiveConnector); rowType_ = ROW({"c0", "c1"}, {INTEGER(), VARCHAR()}); diff --git a/presto-native-execution/velox b/presto-native-execution/velox index 426516c093983..904378eb6349c 160000 --- a/presto-native-execution/velox +++ b/presto-native-execution/velox @@ -1 +1 @@ -Subproject commit 426516c093983b7b5230292dba707191c10c7bbf +Subproject commit 904378eb6349c2665f3aa324dd6cd5a13840baed