diff --git a/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.cpp b/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.cpp index bddde97e28e13..2d6c6ed7140d5 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.cpp +++ b/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.cpp @@ -25,15 +25,17 @@ using facebook::velox::functions::remote::PageFormat; namespace facebook::presto::functions::remote::rest { -namespace { -// Returns the serialization/deserialization format used by the remote function -// server. The format is determined by the system configuration value -// "remoteFunctionServerSerde". Supported formats: -// - "presto_page": Uses Presto page format. -// - "spark_unsafe_row": Uses Spark unsafe row format. -// @return PageFormat enum value corresponding to the configured serde format. -// @throws VeloxException if the configured format is unknown. -PageFormat getSerdeFormat() { + +PrestoRestFunctionRegistration::PrestoRestFunctionRegistration() + : kRemoteFunctionServerRestURL_( + SystemConfig::instance()->remoteFunctionServerRestURL()) {} + +PrestoRestFunctionRegistration& PrestoRestFunctionRegistration::getInstance() { + static PrestoRestFunctionRegistration instance; + return instance; +} + +PageFormat PrestoRestFunctionRegistration::getSerdeFormat() { static const auto serdeFormat = SystemConfig::instance()->remoteFunctionServerSerde(); if (serdeFormat == "presto_page") { @@ -46,17 +48,13 @@ PageFormat getSerdeFormat() { } } -// Encodes a string for safe inclusion in a URL by escaping non-alphanumeric -// characters using percent-encoding. Alphanumeric characters and '-', '_', '.', -// '~' are left unchanged. All other characters are replaced with '%' followed -// by their two-digit hexadecimal value. -// @param value The input string to encode. -// @return The URL-encoded string. -std::string urlEncode(const std::string& value) { +std::string PrestoRestFunctionRegistration::urlEncode( + const std::string& value) { return boost::urls::encode(value, boost::urls::unreserved_chars); } -std::string getFunctionName(const protocol::SqlFunctionId& functionId) { +std::string PrestoRestFunctionRegistration::getFunctionName( + const protocol::SqlFunctionId& functionId) { // Example: "namespace.schema.function;TYPE;TYPE". const auto nameEnd = functionId.find(';'); // Assuming the possibility of missing ';' if there are no function arguments. @@ -64,13 +62,8 @@ std::string getFunctionName(const protocol::SqlFunctionId& functionId) { : functionId; } -// Constructs a Velox function signature from a Presto function signature. This -// function translates type variable constraints, integer variable constraints, -// return type, argument types, and variable arity from the Presto signature to -// the corresponding Velox signature builder. -// @param prestoSignature The Presto function signature to convert. -// @return A pointer to the constructed Velox function signature. -velox::exec::FunctionSignaturePtr buildVeloxSignatureFromPrestoSignature( +velox::exec::FunctionSignaturePtr +PrestoRestFunctionRegistration::buildVeloxSignatureFromPrestoSignature( const protocol::Signature& prestoSignature) { velox::exec::FunctionSignatureBuilder signatureBuilder; @@ -93,48 +86,48 @@ velox::exec::FunctionSignaturePtr buildVeloxSignatureFromPrestoSignature( return signatureBuilder.build(); } -} // namespace +std::string PrestoRestFunctionRegistration::getRemoteFunctionServerUrl( + const protocol::RestFunctionHandle& restFunctionHandle) const { + if (restFunctionHandle.executionEndpoint && + !restFunctionHandle.executionEndpoint->empty()) { + return *restFunctionHandle.executionEndpoint; + } + return kRemoteFunctionServerRestURL_; +} -void registerRestRemoteFunction( +void PrestoRestFunctionRegistration::registerFunction( const protocol::RestFunctionHandle& restFunctionHandle) { - static std::mutex registrationMutex; - static std::unordered_map registeredFunctionHandles; - static std::unordered_map - restClient; - static const std::string remoteFunctionServerRestURL = - SystemConfig::instance()->remoteFunctionServerRestURL(); - const std::string functionId = restFunctionHandle.functionId; + const std::string remoteFunctionServerRestURL = + getRemoteFunctionServerUrl(restFunctionHandle); json functionHandleJson; to_json(functionHandleJson, restFunctionHandle); functionHandleJson["url"] = remoteFunctionServerRestURL; const std::string serializedFunctionHandle = functionHandleJson.dump(); - // Check if already registered (read-only, no lock needed for initial check) { - std::lock_guard lock(registrationMutex); - auto it = registeredFunctionHandles.find(functionId); - if (it != registeredFunctionHandles.end() && + std::lock_guard lock(registrationMutex_); + auto it = registeredFunctionHandles_.find(functionId); + if (it != registeredFunctionHandles_.end() && it->second == serializedFunctionHandle) { return; } } // Get or create shared RestRemoteClient for this server URL - functions::rest::RestRemoteClientPtr remoteClient; + RestRemoteClientPtr remoteClient; { - std::lock_guard lock(registrationMutex); - auto clientIt = restClient.find(remoteFunctionServerRestURL); - if (clientIt == restClient.end()) { - restClient[remoteFunctionServerRestURL] = - std::make_shared( - remoteFunctionServerRestURL); + std::lock_guard lock(registrationMutex_); + auto clientIt = restClients_.find(remoteFunctionServerRestURL); + if (clientIt == restClients_.end()) { + restClients_[remoteFunctionServerRestURL] = + std::make_shared(remoteFunctionServerRestURL); } - remoteClient = restClient[remoteFunctionServerRestURL]; + remoteClient = restClients_[remoteFunctionServerRestURL]; } - functions::rest::VeloxRemoteFunctionMetadata metadata; + VeloxRemoteFunctionMetadata metadata; // Extract function name parts using the utility function const std::string functionName = @@ -158,7 +151,7 @@ void registerRestRemoteFunction( std::vector veloxSignatures = { veloxSignature}; - functions::rest::registerVeloxRemoteFunction( + registerVeloxRemoteFunction( getFunctionName(restFunctionHandle.functionId), veloxSignatures, metadata, @@ -166,8 +159,8 @@ void registerRestRemoteFunction( // Update registration map { - std::lock_guard lock(registrationMutex); - registeredFunctionHandles[functionId] = serializedFunctionHandle; + std::lock_guard lock(registrationMutex_); + registeredFunctionHandles_[functionId] = serializedFunctionHandle; } } } // namespace facebook::presto::functions::remote::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h b/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h index fcaf3d086f184..57f36fdc0e64a 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h +++ b/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h @@ -14,10 +14,102 @@ #pragma once +#include +#include +#include +#include + +#include "presto_cpp/main/functions/remote/RestRemoteFunction.h" +#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h" #include "presto_cpp/presto_protocol/presto_protocol.h" +#include "velox/expression/FunctionSignature.h" namespace facebook::presto::functions::remote::rest { -void registerRestRemoteFunction( - const protocol::RestFunctionHandle& restFunctionHandle); +/// Manages registration of REST-based remote functions in Velox. +/// This class provides a thread-safe singleton interface for registering +/// remote functions that are accessed via REST API endpoints. +class PrestoRestFunctionRegistration { + public: + /// Returns the singleton instance of the registration manager. + /// @return Reference to the singleton instance. + static PrestoRestFunctionRegistration& getInstance(); + + /// Registers a REST remote function with Velox. + /// This method is thread-safe and handles duplicate registrations. + /// @param restFunctionHandle The Presto REST function handle containing + /// function metadata, signature, and location information. + void registerFunction(const protocol::RestFunctionHandle& restFunctionHandle); + + // Delete copy constructor and assignment operator + PrestoRestFunctionRegistration(const PrestoRestFunctionRegistration&) = + delete; + PrestoRestFunctionRegistration& operator=( + const PrestoRestFunctionRegistration&) = delete; + + private: + // Private constructor for singleton pattern. + PrestoRestFunctionRegistration(); + + // Resolves the remote function server URL from the function handle. + // @param restFunctionHandle The Presto REST function handle that may + // contain an execution endpoint. + // @return The resolved remote function server URL. + std::string getRemoteFunctionServerUrl( + const protocol::RestFunctionHandle& restFunctionHandle) const; + + // Returns the serialization/deserialization format used by the remote + // function server. + // @return PageFormat enum value corresponding to the configured serde + // format. + static velox::functions::remote::PageFormat getSerdeFormat(); + + // Encodes a string for safe inclusion in a URL. + // @param value The input string to encode. + // @return The URL-encoded string. + static std::string urlEncode(const std::string& value); + + // Extracts the function name from a function ID. + // @param functionId The SQL function ID. + // @return The function name without type parameters. + static std::string getFunctionName(const protocol::SqlFunctionId& functionId); + + // Constructs a Velox function signature from a Presto function signature. + // @param prestoSignature The Presto function signature to convert. + // @return A pointer to the constructed Velox function signature. + static velox::exec::FunctionSignaturePtr + buildVeloxSignatureFromPrestoSignature( + const protocol::Signature& prestoSignature); + + // Mutex for thread-safe registration operations. + std::mutex registrationMutex_; + + // Map of registered function IDs to their serialized handles. + std::unordered_map registeredFunctionHandles_; + + // Map of REST server URLs to their corresponding client instances. + std::unordered_map restClients_; + + // The base URL for the remote function server REST API. + const std::string kRemoteFunctionServerRestURL_; + + VELOX_FRIEND_TEST( + PrestoRestFunctionRegistrationTest, + getRemoteFunctionServerUrlWithExecutionEndpoint); + VELOX_FRIEND_TEST( + PrestoRestFunctionRegistrationTest, + getRemoteFunctionServerUrlWithEmptyExecutionEndpoint); + VELOX_FRIEND_TEST( + PrestoRestFunctionRegistrationTest, + getRemoteFunctionServerUrlWithoutExecutionEndpoint); + VELOX_FRIEND_TEST( + PrestoRestFunctionRegistrationTest, + getRemoteFunctionServerUrlConsistency); + VELOX_FRIEND_TEST( + PrestoRestFunctionRegistrationTest, + getRemoteFunctionServerUrlWithDifferentProtocols); + VELOX_FRIEND_TEST( + PrestoRestFunctionRegistrationTest, + getRemoteFunctionServerUrlWithComplexUrls); +}; } // namespace facebook::presto::functions::remote::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.cpp b/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.cpp index b8f1c3cade7e5..98610b325f02c 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.cpp +++ b/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.cpp @@ -17,7 +17,7 @@ #include "velox/functions/remote/client/RemoteVectorFunction.h" using namespace facebook::velox; -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest { namespace { class RestRemoteFunction : public velox::functions::RemoteVectorFunction { @@ -100,4 +100,4 @@ void registerVeloxRemoteFunction( overwrite); } -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.h b/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.h index 408fc6700cc6f..a31b48295f80f 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.h +++ b/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.h @@ -17,7 +17,7 @@ #include "presto_cpp/main/functions/remote/client/RestRemoteClient.h" #include "velox/functions/remote/client/RemoteVectorFunction.h" -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest { struct VeloxRemoteFunctionMetadata : public velox::functions::RemoteVectorFunctionMetadata { @@ -32,4 +32,4 @@ void registerVeloxRemoteFunction( RestRemoteClientPtr restClient, bool overwrite = true); -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.cpp b/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.cpp index 0d8b94710fe83..b39483fb20d4b 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.cpp +++ b/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.cpp @@ -23,7 +23,7 @@ using namespace facebook::velox; -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest { namespace { inline std::string getContentType(velox::functions::remote::PageFormat fmt) { return fmt == velox::functions::remote::PageFormat::SPARK_UNSAFE_ROW @@ -110,4 +110,4 @@ std::unique_ptr RestRemoteClient::invokeFunction( return nullptr; } -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.h b/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.h index 9fe89944c7202..e9af5ac432ed7 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.h +++ b/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.h @@ -20,7 +20,7 @@ #include "presto_cpp/main/http/HttpClient.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunction_types.h" -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest { class RestRemoteClient { public: @@ -50,4 +50,4 @@ class RestRemoteClient { using RestRemoteClientPtr = std::shared_ptr; -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/remote/tests/CMakeLists.txt index 59baa1eb24fef..b13801f7daa74 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/CMakeLists.txt @@ -28,3 +28,16 @@ target_link_libraries( GTest::gtest GTest::gtest_main ) + +add_executable(presto_rest_function_registration_test PrestoRestFunctionRegistrationTest.cpp) + +add_test(presto_rest_function_registration_test presto_rest_function_registration_test) + +target_link_libraries( + presto_rest_function_registration_test + presto_to_velox_remote_functions + presto_functions_remote + presto_protocol + GTest::gtest + GTest::gtest_main +) diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/PrestoRestFunctionRegistrationTest.cpp b/presto-native-execution/presto_cpp/main/functions/remote/tests/PrestoRestFunctionRegistrationTest.cpp new file mode 100644 index 0000000000000..3d9c11a85f620 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/PrestoRestFunctionRegistrationTest.cpp @@ -0,0 +1,346 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h" +#include +#include "presto_cpp/main/common/Configs.h" +#include "presto_cpp/presto_protocol/presto_protocol.h" + +using namespace facebook::presto; + +namespace facebook::presto::functions::remote::rest { + +class PrestoRestFunctionRegistrationTest : public ::testing::Test { + protected: + void SetUp() override { + auto systemConfig = SystemConfig::instance(); + std::unordered_map config{ + {std::string(SystemConfig::kRemoteFunctionServerSerde), + std::string("presto_page")}, + {std::string(SystemConfig::kRemoteFunctionServerRestURL), + std::string("http://localhost:8080")}}; + systemConfig->initialize( + std::make_unique(std::move(config), true)); + } + + std::shared_ptr parseRestFunctionHandle( + const std::string& jsonStr) { + const json j = json::parse(jsonStr); + return j; + } +}; + +TEST_F( + PrestoRestFunctionRegistrationTest, + getRemoteFunctionServerUrlWithExecutionEndpoint) { + // Test when executionEndpoint is provided + const std::string jsonStr = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.testFunction;BIGINT", + "version": "v1", + "signature": { + "name": "testFunction", + "kind": "SCALAR", + "returnType": "bigint", + "argumentTypes": ["bigint"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + }, + "executionEndpoint": "http://custom-server:8080" + } + )JSON"; + + auto handle = parseRestFunctionHandle(jsonStr); + ASSERT_NE(handle, nullptr); + EXPECT_EQ(*handle->executionEndpoint, "http://custom-server:8080"); + + EXPECT_EQ(handle->signature.name, "testFunction"); + EXPECT_EQ(handle->signature.kind, protocol::FunctionKind::SCALAR); + EXPECT_EQ(handle->signature.returnType, "bigint"); + ASSERT_EQ(handle->signature.argumentTypes.size(), 1); + EXPECT_EQ(handle->signature.argumentTypes[0], "bigint"); + EXPECT_FALSE(handle->signature.variableArity); + + std::string result = + PrestoRestFunctionRegistration::getInstance().getRemoteFunctionServerUrl( + *handle); + + EXPECT_EQ(result, "http://custom-server:8080"); +} + +TEST_F( + PrestoRestFunctionRegistrationTest, + getRemoteFunctionServerUrlWithEmptyExecutionEndpoint) { + // Test when executionEndpoint is provided but empty + const std::string jsonStr = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.testFunction;BIGINT", + "version": "v1", + "signature": { + "name": "testFunction", + "kind": "SCALAR", + "returnType": "bigint", + "argumentTypes": ["bigint"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + }, + "executionEndpoint": "" + } + )JSON"; + + auto handle = parseRestFunctionHandle(jsonStr); + ASSERT_NE(handle, nullptr); + EXPECT_EQ(*handle->executionEndpoint, ""); + + EXPECT_EQ(handle->signature.name, "testFunction"); + EXPECT_EQ(handle->signature.kind, protocol::FunctionKind::SCALAR); + EXPECT_EQ(handle->signature.returnType, "bigint"); + ASSERT_EQ(handle->signature.argumentTypes.size(), 1); + EXPECT_EQ(handle->signature.argumentTypes[0], "bigint"); + EXPECT_FALSE(handle->signature.variableArity); + + auto& instance = PrestoRestFunctionRegistration::getInstance(); + std::string result = instance.getRemoteFunctionServerUrl(*handle); + + // Should fall back to default URL (kRemoteFunctionServerRestURL_) + EXPECT_EQ(result, "http://localhost:8080"); +} + +TEST_F( + PrestoRestFunctionRegistrationTest, + getRemoteFunctionServerUrlWithoutExecutionEndpoint) { + // Test when executionEndpoint is not provided (nullopt) + const std::string jsonStr = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.testFunction;BIGINT", + "version": "v1", + "signature": { + "name": "testFunction", + "kind": "SCALAR", + "returnType": "bigint", + "argumentTypes": ["bigint"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + } + } + )JSON"; + + auto handle = parseRestFunctionHandle(jsonStr); + ASSERT_NE(handle, nullptr); + + EXPECT_EQ(handle->signature.name, "testFunction"); + EXPECT_EQ(handle->signature.kind, protocol::FunctionKind::SCALAR); + EXPECT_EQ(handle->signature.returnType, "bigint"); + ASSERT_EQ(handle->signature.argumentTypes.size(), 1); + EXPECT_EQ(handle->signature.argumentTypes[0], "bigint"); + EXPECT_FALSE(handle->signature.variableArity); + + auto& instance = PrestoRestFunctionRegistration::getInstance(); + std::string result = instance.getRemoteFunctionServerUrl(*handle); + + // Should fall back to default URL (kRemoteFunctionServerRestURL_) + EXPECT_EQ(result, "http://localhost:8080"); +} + +TEST_F( + PrestoRestFunctionRegistrationTest, + getRemoteFunctionServerUrlConsistency) { + // Test that the same input produces the same output + const std::string jsonStr1 = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.function1;BIGINT", + "version": "v1", + "signature": { + "name": "function1", + "kind": "SCALAR", + "returnType": "bigint", + "argumentTypes": ["bigint"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + }, + "executionEndpoint": "http://server1:9090" + } + )JSON"; + + const std::string jsonStr2 = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.function2;BIGINT", + "version": "v1", + "signature": { + "name": "function2", + "kind": "SCALAR", + "returnType": "bigint", + "argumentTypes": ["bigint"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + }, + "executionEndpoint": "http://server2:9091" + } + )JSON"; + + const std::string jsonStr3 = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.function3;BIGINT", + "version": "v1", + "signature": { + "name": "function3", + "kind": "SCALAR", + "returnType": "bigint", + "argumentTypes": ["bigint"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + }, + "executionEndpoint": "http://server1:9090" + } + )JSON"; + + auto handle1 = parseRestFunctionHandle(jsonStr1); + auto handle2 = parseRestFunctionHandle(jsonStr2); + auto handle3 = parseRestFunctionHandle(jsonStr3); + + EXPECT_EQ(handle1->signature.name, "function1"); + EXPECT_EQ(handle1->signature.kind, protocol::FunctionKind::SCALAR); + + EXPECT_EQ(handle2->signature.name, "function2"); + EXPECT_EQ(handle2->signature.kind, protocol::FunctionKind::SCALAR); + + EXPECT_EQ(handle3->signature.name, "function3"); + EXPECT_EQ(handle3->signature.kind, protocol::FunctionKind::SCALAR); + + auto& instance = PrestoRestFunctionRegistration::getInstance(); + std::string result1 = instance.getRemoteFunctionServerUrl(*handle1); + std::string result2 = instance.getRemoteFunctionServerUrl(*handle2); + std::string result3 = instance.getRemoteFunctionServerUrl(*handle3); + + EXPECT_EQ(result1, "http://server1:9090"); + EXPECT_EQ(result2, "http://server2:9091"); + EXPECT_EQ(result3, "http://server1:9090"); + EXPECT_EQ(result1, result3); + EXPECT_NE(result1, result2); +} + +TEST_F( + PrestoRestFunctionRegistrationTest, + getRemoteFunctionServerUrlWithDifferentProtocols) { + // Test with different URL protocols + const std::string httpJsonStr = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.httpFunction;BIGINT", + "version": "v1", + "signature": { + "name": "httpFunction", + "kind": "SCALAR", + "returnType": "bigint", + "argumentTypes": ["bigint"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + }, + "executionEndpoint": "http://server:8080" + } + )JSON"; + + const std::string httpsJsonStr = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.httpsFunction;BIGINT", + "version": "v1", + "signature": { + "name": "httpsFunction", + "kind": "SCALAR", + "returnType": "bigint", + "argumentTypes": ["bigint"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + }, + "executionEndpoint": "https://secure-server:8443" + } + )JSON"; + + auto httpHandle = parseRestFunctionHandle(httpJsonStr); + auto httpsHandle = parseRestFunctionHandle(httpsJsonStr); + + EXPECT_EQ(httpHandle->signature.name, "httpFunction"); + EXPECT_EQ(httpHandle->signature.kind, protocol::FunctionKind::SCALAR); + EXPECT_EQ(httpHandle->signature.returnType, "bigint"); + + EXPECT_EQ(httpsHandle->signature.name, "httpsFunction"); + EXPECT_EQ(httpsHandle->signature.kind, protocol::FunctionKind::SCALAR); + EXPECT_EQ(httpsHandle->signature.returnType, "bigint"); + + std::string httpResult = + PrestoRestFunctionRegistration::getInstance().getRemoteFunctionServerUrl( + *httpHandle); + std::string httpsResult = + PrestoRestFunctionRegistration::getInstance().getRemoteFunctionServerUrl( + *httpsHandle); + + EXPECT_EQ(httpResult, "http://server:8080"); + EXPECT_EQ(httpsResult, "https://secure-server:8443"); +} + +TEST_F( + PrestoRestFunctionRegistrationTest, + getRemoteFunctionServerUrlWithComplexUrls) { + // Test with URLs containing paths and query parameters + const std::string jsonStr = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.complexFunction;BIGINT", + "version": "v1", + "signature": { + "name": "complexFunction", + "kind": "SCALAR", + "returnType": "bigint", + "argumentTypes": ["bigint"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + }, + "executionEndpoint": "http://server:8080/api/v1/functions?param=value" + } + )JSON"; + + auto handle = parseRestFunctionHandle(jsonStr); + ASSERT_NE(handle, nullptr); + + EXPECT_EQ(handle->signature.name, "complexFunction"); + EXPECT_EQ(handle->signature.kind, protocol::FunctionKind::SCALAR); + EXPECT_EQ(handle->signature.returnType, "bigint"); + ASSERT_EQ(handle->signature.argumentTypes.size(), 1); + EXPECT_EQ(handle->signature.argumentTypes[0], "bigint"); + EXPECT_FALSE(handle->signature.variableArity); + + std::string result = + PrestoRestFunctionRegistration::getInstance().getRemoteFunctionServerUrl( + *handle); + + EXPECT_EQ(result, "http://server:8080/api/v1/functions?param=value"); +} + +} // namespace facebook::presto::functions::remote::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/RemoteFunctionRestTest.cpp b/presto-native-execution/presto_cpp/main/functions/remote/tests/RemoteFunctionRestTest.cpp index ddf233b7647d9..63a8af23f1e0b 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/tests/RemoteFunctionRestTest.cpp +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/RemoteFunctionRestTest.cpp @@ -35,7 +35,7 @@ using ::facebook::velox::test::assertEqualVectors; using namespace facebook::velox; -namespace facebook::presto::functions::rest::test { +namespace facebook::presto::functions::remote::rest::test { namespace { class RemoteFunctionRestTest @@ -290,7 +290,7 @@ VELOX_INSTANTIATE_TEST_SUITE_P( velox::functions::remote::PageFormat::SPARK_UNSAFE_ROW)); } // namespace -} // namespace facebook::presto::functions::rest::test +} // namespace facebook::presto::functions::remote::rest::test int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h index b33dac334f861..86f9a59a4a1d6 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h @@ -17,7 +17,7 @@ #include "velox/vector/VectorStream.h" -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest::test { class RemoteFunctionRestHandler { public: @@ -74,4 +74,4 @@ class RemoteFunctionRestHandler { std::string& errorMessage) = 0; }; -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest::test diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.cpp b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.cpp index 2c944687f1f3f..0c827030ffa59 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.cpp +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.cpp @@ -22,7 +22,7 @@ #include "velox/serializers/UnsafeRowSerializer.h" using namespace facebook::velox; -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest::test { RestSession::RestSession(boost::asio::ip::tcp::socket socket) : socket_(std::move(socket)), @@ -311,4 +311,4 @@ void RestListener::onAccept( doAccept(); } -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest::test diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h index e24d889dd64ff..3f888f56322a7 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h @@ -21,7 +21,7 @@ #include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" #include "presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h" -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest::test { /// @brief Manages an individual HTTP session. /// Handles reading HTTP requests, processing them, and sending responses. @@ -122,4 +122,4 @@ class RestListener : public std::enable_shared_from_this { boost::asio::ip::tcp::acceptor acceptor_; }; -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest::test diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.cpp b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.cpp index 4cd56c209cd84..6bd6551e9a281 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.cpp +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.cpp @@ -16,7 +16,7 @@ #include -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest::test { RestFunctionRegistry& RestFunctionRegistry::getInstance() { static RestFunctionRegistry instance; @@ -66,4 +66,4 @@ void RestFunctionRegistry::clear() { functionHandlers_.clear(); } -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest::test diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h index 4284e3c086703..71c17bb655a71 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h @@ -20,7 +20,7 @@ #include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest::test { /// @brief Registry for remote function REST handlers. /// Provides a centralized location for registering and looking up function @@ -74,4 +74,4 @@ class RestFunctionRegistry { functionHandlers_; }; -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest::test diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteDoubleDivHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteDoubleDivHandler.h index b55c4641ab460..e8a23322be0bb 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteDoubleDivHandler.h +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteDoubleDivHandler.h @@ -16,7 +16,7 @@ #include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest::test { class RemoteDoubleDivHandler : public RemoteFunctionRestHandler { public: @@ -58,4 +58,4 @@ class RemoteDoubleDivHandler : public RemoteFunctionRestHandler { } }; -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest::test diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteFibonacciHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteFibonacciHandler.h index 8de2bd9694041..1dfbb510e86fc 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteFibonacciHandler.h +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteFibonacciHandler.h @@ -17,7 +17,7 @@ #include #include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest::test { class RemoteFibonacciHandler : public RemoteFunctionRestHandler { public: @@ -51,4 +51,4 @@ class RemoteFibonacciHandler : public RemoteFunctionRestHandler { } }; -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest::test diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteInverseCdfHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteInverseCdfHandler.h index bd65d34de1298..a4498b455df89 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteInverseCdfHandler.h +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteInverseCdfHandler.h @@ -17,7 +17,7 @@ #include #include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest::test { namespace { inline double inverse_chi_squared_cdf(double p, double nu) { if (p <= 0.0 || p >= 1.0) { @@ -73,4 +73,4 @@ class RemoteInverseCdfHandler : public RemoteFunctionRestHandler { } }; -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest::test diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteRemoveCharHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteRemoveCharHandler.h index b798189258128..04e381976d7cd 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteRemoveCharHandler.h +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteRemoveCharHandler.h @@ -16,7 +16,7 @@ #include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest::test { class RemoteRemoveCharHandler final : public RemoteFunctionRestHandler { public: @@ -61,4 +61,4 @@ class RemoteRemoveCharHandler final : public RemoteFunctionRestHandler { } }; -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest::test diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteStrLenHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteStrLenHandler.h index 980a50937048c..84cd061dcd52b 100644 --- a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteStrLenHandler.h +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteStrLenHandler.h @@ -16,7 +16,7 @@ #include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" -namespace facebook::presto::functions::rest { +namespace facebook::presto::functions::remote::rest::test { class RemoteStrLenHandler : public RemoteFunctionRestHandler { public: RemoteStrLenHandler() = default; @@ -49,4 +49,4 @@ class RemoteStrLenHandler : public RemoteFunctionRestHandler { } }; -} // namespace facebook::presto::functions::rest +} // namespace facebook::presto::functions::remote::rest::test diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index 27b6fc810595f..d63509c284a4e 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -538,7 +538,8 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr( auto args = toVeloxExpr(pexpr.arguments); auto returnType = typeParser_->parse(pexpr.returnType); - functions::remote::rest::registerRestRemoteFunction(*restFunctionHandle); + functions::remote::rest::PrestoRestFunctionRegistration::getInstance() + .registerFunction(*restFunctionHandle); return std::make_shared( returnType, args, getFunctionName(restFunctionHandle->functionId)); }