diff --git a/.gitignore b/.gitignore index a4512f9f794d..bf975bab3ff0 100644 --- a/.gitignore +++ b/.gitignore @@ -66,3 +66,4 @@ presto-native-execution/deps-install # Compiled executables used for docker build /docker/presto-cli-*-executable.jar /docker/presto-server-*.tar.gz +/docker/presto-function-server-executable.jar diff --git a/docker/Dockerfile b/docker/Dockerfile index d44de58d549b..4ce01b8818d7 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -3,6 +3,7 @@ FROM quay.io/centos/centos:stream9 ARG PRESTO_VERSION ARG PRESTO_PKG=presto-server-$PRESTO_VERSION.tar.gz ARG PRESTO_CLI_JAR=presto-cli-$PRESTO_VERSION-executable.jar +ARG PRESTO_REMOTE_SERVER_JAR=presto-function-server-executable.jar ARG JMX_PROMETHEUS_JAVAAGENT_VERSION=0.20.0 ENV PRESTO_HOME="/opt/presto-server" @@ -13,7 +14,8 @@ RUN --mount=type=cache,target=/var/cache/dnf,sharing=locked \ # clean cache jobs && mv /etc/yum/protected.d/systemd.conf /etc/yum/protected.d/systemd.conf.bak -COPY --chmod=755 $PRESTO_CLI_JAR /opt/presto-cli +COPY --chmod=755 $PRESTO_CLI_JAR /opt/presto-cli +COPY --chmod=755 $PRESTO_REMOTE_SERVER_JAR /opt/presto-remote-function-server RUN --mount=type=bind,source=$PRESTO_PKG,target=/$PRESTO_PKG \ # Download Presto and move \ diff --git a/presto-native-execution/.dockerignore b/presto-native-execution/.dockerignore new file mode 100644 index 000000000000..b25c310f8162 --- /dev/null +++ b/presto-native-execution/.dockerignore @@ -0,0 +1,4 @@ +# Ignore build directories +_build/ +cmake-build-debug/ +cmake-build-release/ diff --git a/presto-native-execution/CMakeLists.txt b/presto-native-execution/CMakeLists.txt index d7becdc6188a..fc5d1ddf6269 100644 --- a/presto-native-execution/CMakeLists.txt +++ b/presto-native-execution/CMakeLists.txt @@ -136,7 +136,7 @@ set(Boost_USE_MULTITHREADED TRUE) find_package( Boost 1.66.0 - REQUIRED program_options context filesystem regex thread system date_time atomic + REQUIRED program_options context filesystem regex thread system 
date_time url atomic ) include_directories(SYSTEM ${Boost_INCLUDE_DIRS}) diff --git a/presto-native-execution/pom.xml b/presto-native-execution/pom.xml index 9eeb4b10250e..34c0fcb10701 100644 --- a/presto-native-execution/pom.xml +++ b/presto-native-execution/pom.xml @@ -503,6 +503,12 @@ presto-cli-*-executable.jar + + ${project.parent.basedir}/presto-function-server/target + + presto-function-server-executable.jar + + ${project.parent.basedir}/presto-server/target @@ -574,7 +580,7 @@ Release presto-native-dependency:latest - -DPRESTO_ENABLE_TESTING=OFF + "-DPRESTO_ENABLE_TESTING=OFF -DPRESTO_ENABLE_REMOTE_FUNCTIONS=ON" 2 ubuntu:22.04 ubuntu diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 11811b5923ba..9fcf03384215 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -1391,11 +1391,12 @@ void PrestoServer::registerRemoteFunctions() { << catalogName << "' catalog."; } else { VELOX_FAIL( - "To register remote functions using a json file path you need to " - "specify the remote server location using '{}', '{}' or '{}'.", + "To register remote functions you need to specify the remote server " + "location using '{}', '{}' or '{}' or {}.", SystemConfig::kRemoteFunctionServerThriftAddress, SystemConfig::kRemoteFunctionServerThriftPort, - SystemConfig::kRemoteFunctionServerThriftUdsPath); + SystemConfig::kRemoteFunctionServerThriftUdsPath, + SystemConfig::kRemoteFunctionServerRestURL); } } #endif diff --git a/presto-native-execution/presto_cpp/main/common/Configs.cpp b/presto-native-execution/presto_cpp/main/common/Configs.cpp index 80df7602f3f0..18cd399fdee3 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.cpp +++ b/presto-native-execution/presto_cpp/main/common/Configs.cpp @@ -479,6 +479,10 @@ std::string SystemConfig::remoteFunctionServerSerde() const { return 
optionalProperty(kRemoteFunctionServerSerde).value(); } +std::string SystemConfig::remoteFunctionServerRestURL() const { + return optionalProperty(kRemoteFunctionServerRestURL).value(); +} + int32_t SystemConfig::maxDriversPerTask() const { return optionalProperty(kMaxDriversPerTask).value(); } diff --git a/presto-native-execution/presto_cpp/main/common/Configs.h b/presto-native-execution/presto_cpp/main/common/Configs.h index b3419a69e67e..e7f5d8d5d785 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.h +++ b/presto-native-execution/presto_cpp/main/common/Configs.h @@ -763,6 +763,10 @@ class SystemConfig : public ConfigBase { static constexpr std::string_view kRemoteFunctionServerThriftUdsPath{ "remote-function-server.thrift.uds-path"}; + /// HTTP URL used by the remote function rest server. + static constexpr std::string_view kRemoteFunctionServerRestURL{ + "remote-function-server.rest.url"}; + /// Path where json files containing signatures for remote functions can be /// found. static constexpr std::string_view @@ -924,6 +928,8 @@ class SystemConfig : public ConfigBase { std::string remoteFunctionServerSerde() const; + std::string remoteFunctionServerRestURL() const; + int32_t maxDriversPerTask() const; int32_t driverMaxSplitPreload() const; diff --git a/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt index 88aa887dd400..20020ea182e5 100644 --- a/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt @@ -11,10 +11,14 @@ # limitations under the License. 
add_library(presto_function_metadata OBJECT FunctionMetadata.cpp) -target_link_libraries(presto_function_metadata velox_function_registry) +target_link_libraries(presto_function_metadata presto_common velox_function_registry) add_subdirectory(dynamic_registry) +if(PRESTO_ENABLE_REMOTE_FUNCTIONS) + add_subdirectory(remote) +endif() + if(PRESTO_ENABLE_TESTING) add_subdirectory(tests) endif() diff --git a/presto-native-execution/presto_cpp/main/functions/remote/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/remote/CMakeLists.txt new file mode 100644 index 000000000000..c2c4f8c7990f --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/CMakeLists.txt @@ -0,0 +1,28 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +add_library(presto_to_velox_remote_functions PrestoRestFunctionRegistration.cpp) +target_link_libraries( + presto_to_velox_remote_functions + presto_functions_remote + velox_type_fbhive + Boost::url +) + +add_library(presto_functions_remote RestRemoteFunction.cpp) +target_link_libraries(presto_functions_remote presto_functions_rest_client velox_functions_remote) + +add_subdirectory(client) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.cpp b/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.cpp new file mode 100644 index 000000000000..bddde97e28e1 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.cpp @@ -0,0 +1,173 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h" + +#include +#include + +#include "presto_cpp/main/common/Configs.h" +#include "presto_cpp/main/common/Utils.h" +#include "presto_cpp/main/functions/remote/RestRemoteFunction.h" +#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h" + +using facebook::velox::functions::remote::PageFormat; + +namespace facebook::presto::functions::remote::rest { +namespace { +// Returns the serialization/deserialization format used by the remote function +// server. 
The format is determined by the system configuration value +// "remoteFunctionServerSerde". Supported formats: +// - "presto_page": Uses Presto page format. +// - "spark_unsafe_row": Uses Spark unsafe row format. +// @return PageFormat enum value corresponding to the configured serde format. +// @throws VeloxException if the configured format is unknown. +PageFormat getSerdeFormat() { + static const auto serdeFormat = + SystemConfig::instance()->remoteFunctionServerSerde(); + if (serdeFormat == "presto_page") { + return PageFormat::PRESTO_PAGE; + } else if (serdeFormat == "spark_unsafe_row") { + return PageFormat::SPARK_UNSAFE_ROW; + } else { + VELOX_FAIL( + "Unknown serde name for remote function server: '{}'", serdeFormat); + } +} + +// Encodes a string for safe inclusion in a URL by escaping non-alphanumeric +// characters using percent-encoding. Alphanumeric characters and '-', '_', '.', +// '~' are left unchanged. All other characters are replaced with '%' followed +// by their two-digit hexadecimal value. +// @param value The input string to encode. +// @return The URL-encoded string. +std::string urlEncode(const std::string& value) { + return boost::urls::encode(value, boost::urls::unreserved_chars); +} + +std::string getFunctionName(const protocol::SqlFunctionId& functionId) { + // Example: "namespace.schema.function;TYPE;TYPE". + const auto nameEnd = functionId.find(';'); + // Assuming the possibility of missing ';' if there are no function arguments. + return nameEnd != std::string::npos ? functionId.substr(0, nameEnd) + : functionId; +} + +// Constructs a Velox function signature from a Presto function signature. This +// function translates type variable constraints, integer variable constraints, +// return type, argument types, and variable arity from the Presto signature to +// the corresponding Velox signature builder. +// @param prestoSignature The Presto function signature to convert. 
+// @return A pointer to the constructed Velox function signature. +velox::exec::FunctionSignaturePtr buildVeloxSignatureFromPrestoSignature( + const protocol::Signature& prestoSignature) { + velox::exec::FunctionSignatureBuilder signatureBuilder; + + for (const auto& typeVar : prestoSignature.typeVariableConstraints) { + signatureBuilder.typeVariable(typeVar.name); + } + + for (const auto& longVar : prestoSignature.longVariableConstraints) { + signatureBuilder.integerVariable(longVar.name); + } + signatureBuilder.returnType(prestoSignature.returnType); + + for (const auto& argType : prestoSignature.argumentTypes) { + signatureBuilder.argumentType(argType); + } + + if (prestoSignature.variableArity) { + signatureBuilder.variableArity(); + } + return signatureBuilder.build(); +} + +} // namespace + +void registerRestRemoteFunction( + const protocol::RestFunctionHandle& restFunctionHandle) { + static std::mutex registrationMutex; + static std::unordered_map<std::string, std::string> registeredFunctionHandles; + static std::unordered_map<std::string, functions::rest::RestRemoteClientPtr> + restClient; + static const std::string remoteFunctionServerRestURL = + SystemConfig::instance()->remoteFunctionServerRestURL(); + + const std::string& functionId = restFunctionHandle.functionId; + + json functionHandleJson; + to_json(functionHandleJson, restFunctionHandle); + functionHandleJson["url"] = remoteFunctionServerRestURL; + const std::string serializedFunctionHandle = functionHandleJson.dump(); + + // Under the lock: return early if this exact handle is already registered. 
+ { + std::lock_guard lock(registrationMutex); + auto it = registeredFunctionHandles.find(functionId); + if (it != registeredFunctionHandles.end() && + it->second == serializedFunctionHandle) { + return; + } + } + + // Get or create the shared RestRemoteClient cached for this server URL. + functions::rest::RestRemoteClientPtr remoteClient; + { + std::lock_guard lock(registrationMutex); + auto clientIt = restClient.find(remoteFunctionServerRestURL); + if (clientIt == restClient.end()) { 
clientIt = restClient.emplace(remoteFunctionServerRestURL, + std::make_shared<functions::rest::RestRemoteClient>( + remoteFunctionServerRestURL)).first; + } + remoteClient = clientIt->second; + } + + functions::rest::VeloxRemoteFunctionMetadata metadata; + + // Split the name into parts; at() throws rather than reading out of bounds. + const std::string functionName = + getFunctionName(restFunctionHandle.functionId); + const auto parts = util::getFunctionNameParts(functionName); + const std::string& schema = parts.at(1); + const std::string& function = parts.at(2); + + const std::string functionLocation = fmt::format( + "{}/v1/functions/{}/{}/{}/{}", + remoteFunctionServerRestURL, + schema, + function, + urlEncode(restFunctionHandle.functionId), + restFunctionHandle.version); + metadata.location = functionLocation; + metadata.serdeFormat = getSerdeFormat(); + + auto veloxSignature = + buildVeloxSignatureFromPrestoSignature(restFunctionHandle.signature); + std::vector veloxSignatures = { + veloxSignature}; + + functions::rest::registerVeloxRemoteFunction( + functionName, + veloxSignatures, + metadata, + remoteClient); + + // Record the serialized handle so identical registrations are skipped. + { + std::lock_guard lock(registrationMutex); + registeredFunctionHandles[functionId] = serializedFunctionHandle; + } +} +} // namespace facebook::presto::functions::remote::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h b/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h new file mode 100644 index 000000000000..fcaf3d086f18 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/presto_protocol/presto_protocol.h" + +namespace facebook::presto::functions::remote::rest { + +void registerRestRemoteFunction( + const protocol::RestFunctionHandle& restFunctionHandle); +} // namespace facebook::presto::functions::remote::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.cpp b/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.cpp new file mode 100644 index 000000000000..b8f1c3cade7e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.cpp @@ -0,0 +1,103 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "presto_cpp/main/functions/remote/RestRemoteFunction.h" +#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h" +#include "velox/functions/remote/client/RemoteVectorFunction.h" + +using namespace facebook::velox; +namespace facebook::presto::functions::rest { +namespace { + +class RestRemoteFunction : public velox::functions::RemoteVectorFunction { + public: + RestRemoteFunction( + const std::string& functionName, + const std::vector& inputArgs, + const VeloxRemoteFunctionMetadata& metadata, + RestRemoteClientPtr restClient) + : RemoteVectorFunction(functionName, inputArgs, metadata), + location_(metadata.location), + serdeFormat_(metadata.serdeFormat), + restClient_(std::move(restClient)) {} + + protected: + std::unique_ptr + invokeRemoteFunction( + const velox::functions::remote::RemoteFunctionRequest& request) + const override { + VELOX_CHECK(restClient_, "Remote client not initialized."); + + // Clone the request payload for the REST call + auto requestBody = request.inputs()->payload()->clone(); + + auto responseBody = restClient_->invokeFunction( + location_, serdeFormat_, std::move(requestBody)); + + if (!responseBody) { + VELOX_FAIL("No response received from remote function invocation."); + } + + // Convert REST response to RemoteFunctionResponse + auto response = + std::make_unique(); + velox::functions::remote::RemoteFunctionPage result; + result.payload_ref() = std::move(*responseBody); + response->result_ref() = std::move(result); + return response; + } + + std::string remoteLocationToString() const override { + return location_; + } + + private: + const std::string location_; + const velox::functions::remote::PageFormat serdeFormat_; + const RestRemoteClientPtr restClient_; +}; + +std::shared_ptr createRestRemoteFunction( + const std::string& name, + const std::vector& inputArgs, + const core::QueryConfig& /*config*/, + const VeloxRemoteFunctionMetadata& metadata, + RestRemoteClientPtr restClient) { + return 
std::make_shared( + name, inputArgs, metadata, restClient); +} + +} // namespace + +void registerVeloxRemoteFunction( + const std::string& name, + const std::vector& signatures, + VeloxRemoteFunctionMetadata metadata, + RestRemoteClientPtr restClient, + bool overwrite) { + registerStatefulVectorFunction( + name, + signatures, + std::bind( + createRestRemoteFunction, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3, + metadata, + restClient), + std::move(metadata), + overwrite); +} + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.h b/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.h new file mode 100644 index 000000000000..408fc6700cc6 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.h @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h" +#include "velox/functions/remote/client/RemoteVectorFunction.h" + +namespace facebook::presto::functions::rest { + +struct VeloxRemoteFunctionMetadata + : public velox::functions::RemoteVectorFunctionMetadata { + /// URL of the HTTP/REST server for remote function. 
+ std::string location; +}; + +void registerVeloxRemoteFunction( + const std::string& name, + const std::vector& signatures, + VeloxRemoteFunctionMetadata metadata, + RestRemoteClientPtr restClient, + bool overwrite = true); + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/client/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/remote/client/CMakeLists.txt new file mode 100644 index 000000000000..0b714aff8d91 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/client/CMakeLists.txt @@ -0,0 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_library(presto_functions_rest_client RestRemoteClient.cpp) +target_link_libraries(presto_functions_rest_client presto_http) diff --git a/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.cpp b/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.cpp new file mode 100644 index 000000000000..0d8b94710fe8 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.cpp @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h" + +#include +#include + +#include "presto_cpp/main/functions/remote/utils/ContentTypes.h" +#include "velox/common/base/Exceptions.h" +#include "velox/common/memory/Memory.h" + +using namespace facebook::velox; + +namespace facebook::presto::functions::rest { +namespace { +inline std::string getContentType(velox::functions::remote::PageFormat fmt) { + return fmt == velox::functions::remote::PageFormat::SPARK_UNSAFE_ROW + ? remote::CONTENT_TYPE_SPARK_UNSAFE_ROW + : remote::CONTENT_TYPE_PRESTO_PAGE; +} +} // namespace + +RestRemoteClient::RestRemoteClient(const std::string& url) : url_(url) { + memPool_ = memory::MemoryManager::getInstance()->addLeafPool(); + folly::Uri uri(url_); + proxygen::Endpoint endpoint(uri.host(), uri.port(), uri.scheme() == "https"); + folly::SocketAddress addr(uri.host().c_str(), uri.port(), true); + + evbThread_ = std::make_unique("rest-client"); + httpClient_ = std::make_shared( + evbThread_->getEventBase(), + nullptr, + endpoint, + addr, + requestTimeoutMs, + connectTimeoutMs, + memPool_, + nullptr); +} + +RestRemoteClient::~RestRemoteClient() { + if (httpClient_) { + evbThread_->getEventBase()->runInEventBaseThreadAndWait( + [client = std::move(httpClient_)]() mutable { client.reset(); }); + } + evbThread_.reset(); +} + +std::unique_ptr RestRemoteClient::invokeFunction( + const std::string& fullUrl, + velox::functions::remote::PageFormat serdeFormat, + std::unique_ptr requestPayload) const { + try { + folly::Uri uri(fullUrl); + const 
std::string contentType = getContentType(serdeFormat); + auto message = std::make_unique(); + message->setMethod(proxygen::HTTPMethod::POST); + message->setURL(uri.path()); + message->setHTTPVersion(1, 1); + message->getHeaders().add("Content-Type", contentType); + message->getHeaders().add("Accept", contentType); + + requestPayload->coalesce(); + std::string requestBody = requestPayload->moveToFbString().toStdString(); + + auto sendFuture = httpClient_->sendRequest(*message, requestBody); + sendFuture.wait(); + + VELOX_CHECK( + sendFuture.hasValue(), + "Invalid response returned from HTTP request to {}.", + uri.host()); + + std::unique_ptr resp = std::move(sendFuture).get(); + + if (!resp) { + VELOX_FAIL( + "Null response object returned from HTTP request to {}.", uri.host()); + } + + if (resp->hasError()) { + VELOX_FAIL("HTTP error: {}", resp->error()); + } + + int status = resp->headers()->getStatusCode(); + if (status < http::kHttpOk || status >= http::kHttpMultipleChoices) { + VELOX_FAIL( + "Server responded with status {}. Body: '{}'. URL: {}", + status, + resp->dumpBodyChain(), + fullUrl); + } + + return folly::IOBuf::copyBuffer(resp->dumpBodyChain()); + } catch (const std::exception& ex) { + VELOX_FAIL("HTTP invocation failed for URL {}: {}", fullUrl, ex.what()); + } + return nullptr; +} + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.h b/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.h new file mode 100644 index 000000000000..9fe89944c720 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.h @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "presto_cpp/main/common/Configs.h" +#include "presto_cpp/main/http/HttpClient.h" +#include "velox/functions/remote/if/gen-cpp2/RemoteFunction_types.h" + +namespace facebook::presto::functions::rest { + +class RestRemoteClient { + public: + explicit RestRemoteClient(const std::string& url); // No implicit conversion. + + ~RestRemoteClient(); + + std::unique_ptr invokeFunction( + const std::string& fullUrl, + velox::functions::remote::PageFormat serdeFormat, + std::unique_ptr requestPayload) const; + + private: + const std::string url_; + std::unique_ptr evbThread_; + std::shared_ptr httpClient_; + std::shared_ptr memPool_; + + const std::chrono::milliseconds requestTimeoutMs = + std::chrono::duration_cast( + SystemConfig::instance()->exchangeRequestTimeoutMs()); + + const std::chrono::milliseconds connectTimeoutMs = + std::chrono::duration_cast( + SystemConfig::instance()->exchangeConnectTimeoutMs()); +}; + +using RestRemoteClientPtr = std::shared_ptr; + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/remote/tests/CMakeLists.txt new file mode 100644 index 000000000000..59baa1eb24fe --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/CMakeLists.txt @@ -0,0 +1,30 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(server) +endif() + +add_executable(presto_functions_rest_test RemoteFunctionRestTest.cpp) + +add_test(presto_functions_rest_test presto_functions_rest_test) + +target_link_libraries( + presto_functions_rest_test + presto_functions_rest_client + presto_functions_rest_server + presto_functions_remote + velox_functions_test_lib + GTest::gmock + GTest::gtest + GTest::gtest_main +) diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/RemoteFunctionRestTest.cpp b/presto-native-execution/presto_cpp/main/functions/remote/tests/RemoteFunctionRestTest.cpp new file mode 100644 index 000000000000..ddf233b7647d --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/RemoteFunctionRestTest.cpp @@ -0,0 +1,299 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include +#include +#include +#include + +#include "presto_cpp/main/functions/remote/RestRemoteFunction.h" +#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h" +#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h" +#include "presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h" +#include "presto_cpp/main/functions/remote/tests/server/examples/RemoteDoubleDivHandler.h" +#include "presto_cpp/main/functions/remote/tests/server/examples/RemoteFibonacciHandler.h" +#include "presto_cpp/main/functions/remote/tests/server/examples/RemoteInverseCdfHandler.h" +#include "presto_cpp/main/functions/remote/tests/server/examples/RemoteRemoveCharHandler.h" +#include "presto_cpp/main/functions/remote/tests/server/examples/RemoteStrLenHandler.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/exec/tests/utils/PortUtil.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/type/fbhive/HiveTypeParser.h" + +using ::facebook::velox::test::assertEqualVectors; +using namespace facebook::velox; +namespace facebook::presto::functions::rest::test { +namespace { + +class RemoteFunctionRestTest + : public velox::functions::test::FunctionBaseTest, + public testing::WithParamInterface { + public: + void SetUp() override { + auto servicePort = exec::test::getFreePort(); + location_ = fmt::format(kHostAddressTemplate_, servicePort); + wrongLocation_ = + fmt::format(kHostAddressTemplate_, exec::test::getFreePort()); + + restClient_ = std::make_shared(location_); + wrongRestClient_ = std::make_shared(wrongLocation_); + + initializeServer(servicePort); + registerRemoteFunctions(); + } + + ~RemoteFunctionRestTest() override { + if (serverThread_ && serverThread_->joinable()) { + ioc_.stop(); + serverThread_->join(); + } + } + + private: + // Registers a remote function by creating a handler instance and registering + // it on both the server and client side. 
The handler provides its own input + // and output types, eliminating the need for manual type specification. + template + void registerRemoteFunctionHelper( + const std::string& functionName, + const std::string& baseLocation, + RestRemoteClientPtr client) const { + auto handler = std::make_shared(); + + auto inputTypes = handler->getInputTypes(); + auto outputType = handler->getOutputType(); + + auto signatureBuilder = exec::FunctionSignatureBuilder(); + signatureBuilder.returnType(outputType->toString()); + for (const auto& childType : inputTypes->children()) { + signatureBuilder.argumentType(childType->toString()); + } + + RestFunctionRegistry::getInstance().registerFunction(functionName, handler); + + VeloxRemoteFunctionMetadata metadata; + metadata.serdeFormat = GetParam(); + metadata.location = baseLocation + "/" + functionName; + registerVeloxRemoteFunction( + functionName, {signatureBuilder.build()}, metadata, client); + } + + void registerRemoteFunctions() const { + registerRemoteFunctionHelper( + "remote_fibonacci", location_, restClient_); + registerRemoteFunctionHelper( + "remote_strlen", location_, restClient_); + registerRemoteFunctionHelper( + "remote_remove_char", location_, restClient_); + registerRemoteFunctionHelper( + "remote_inverse_cdf", location_, restClient_); + registerRemoteFunctionHelper( + "remote_divide", location_, restClient_); + registerRemoteFunctionHelper( + "remote_wrong_port", wrongLocation_, wrongRestClient_); + + // Register a fake function handler whose logic is intentionally not + // implemented in the server. This simulates a failure scenario for testing + // purposes. 
+ auto roundSignatures = {exec::FunctionSignatureBuilder() + .returnType("integer") + .argumentType("integer") + .build()}; + VeloxRemoteFunctionMetadata metadata; + metadata.serdeFormat = GetParam(); + metadata.location = location_ + "/remote_round"; + registerVeloxRemoteFunction( + "remote_round", roundSignatures, metadata, restClient_); + } + + void initializeServer(uint16_t servicePort) { + serverThread_ = std::make_unique([this, servicePort]() { + std::make_shared( + ioc_, + boost::asio::ip::tcp::endpoint( + boost::asio::ip::make_address(kServiceHost), servicePort)) + ->run(); + ioc_.run(); + }); + + VELOX_CHECK( + waitForRunning(servicePort), "Unable to initialize HTTP server."); + } + + bool waitForRunning(uint16_t servicePort) const { + for (size_t i = 0; i < 100; ++i) { + using boost::asio::ip::tcp; + boost::asio::io_context io_context; + + tcp::socket socket(io_context); + tcp::resolver resolver(io_context); + + try { + boost::asio::connect( + socket, + resolver.resolve(kServiceHost, std::to_string(servicePort))); + return true; + } catch (std::exception& e) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + } + return false; + } + + std::unique_ptr serverThread_; + boost::asio::io_context ioc_{1}; + static constexpr auto kHostAddressTemplate_ = "http://127.0.0.1:{}"; + static constexpr auto kServiceHost = "127.0.0.1"; + + std::string location_; + std::string wrongLocation_; + RestRemoteClientPtr restClient_; + RestRemoteClientPtr wrongRestClient_; +}; + +TEST_P(RemoteFunctionRestTest, connectionError) { + auto numeratorVector = makeFlatVector({0, 1, 4, 9, 16, 25, -25}); + auto denominatorVector = makeFlatVector({0, 1, 2, 3, 4, 0, 2}); + auto data = makeRowVector({numeratorVector, denominatorVector}); + VELOX_ASSERT_THROW( + evaluate>("remote_wrong_port(c0,c1)", data), + "HTTP invocation failed for URL"); +} + +TEST_P(RemoteFunctionRestTest, fibonacci) { + auto inputVector = makeFlatVector({10, 20}); + auto results = evaluate>( + 
"remote_fibonacci(c0)", makeRowVector({inputVector})); + + auto expected = makeFlatVector({55, 6765}); + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, stringLength) { + auto inputVector = + makeFlatVector({"hello", "from", "remote", "server"}); + auto results = evaluate>( + "remote_strlen(c0)", makeRowVector({inputVector})); + + auto expected = makeFlatVector({5, 4, 6, 6}); + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, removeChar) { + auto input = makeFlatVector( + {"hello from remote server", + "testing remote server", + "My file, named 'data_report#2.csv', is located in the folder: C:\\Users\\User\\Documents! It's quite large (~1.5GB)."}); + auto charToRemove = makeFlatVector({"o", "e", "c"}); + auto results = evaluate>( + "remote_remove_char(c0,c1)", makeRowVector({input, charToRemove})); + + auto expected = makeFlatVector( + {"hell frm remte server", + "tsting rmot srvr", + "My file, named 'data_report#2.sv', is loated in the folder: C:\\Users\\User\\Douments! It's quite large (~1.5GB)."}); + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, tryException) { + // remote_divide throws if denominator is 0. 
+ auto numeratorVector = makeFlatVector({0, 1, 4, 9, 16, 25, -25}); + auto denominatorVector = makeFlatVector({0, 1, 2, 3, 4, 0, 2}); + auto data = makeRowVector({numeratorVector, denominatorVector}); + auto results = evaluate>("remote_divide(c0, c1)", data); + + ASSERT_EQ(results->size(), 7); + auto expected = makeFlatVector({0, 1, 2, 3, 4, 0, -12.5}); + expected->setNull(0, true); + expected->setNull(5, true); + + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, inverseCdf) { + auto pVector = makeFlatVector({0.95, 0.95, 0.50, 0.10}); + auto nuVector = makeFlatVector({4, 1, 10, 2}); + auto data = makeRowVector({pVector, nuVector}); + auto results = + evaluate>("remote_inverse_cdf(c0, c1)", data); + + ASSERT_EQ(results->size(), 4); + auto expected = makeFlatVector({9.49, 3.84, 9.34, 0.21}); + + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, inverseCdfException) { + // p < 0, p > 1, nu <= 0 - these should throw exceptions + auto pVector = makeFlatVector({-0.1}); + auto nuVector = makeFlatVector({4}); + auto data = makeRowVector({pVector, nuVector}); + + VELOX_ASSERT_THROW( + evaluate>("remote_inverse_cdf(c0, c1)", data), + "inverse_chi_squared_cdf: p must be in (0,1)"); + + // Test p > 1 + pVector = makeFlatVector({1.1}); + nuVector = makeFlatVector({1}); + data = makeRowVector({pVector, nuVector}); + + VELOX_ASSERT_THROW( + evaluate>("remote_inverse_cdf(c0, c1)", data), + "inverse_chi_squared_cdf: p must be in (0,1)"); + + // Test nu <= 0 + pVector = makeFlatVector({0.5}); + nuVector = makeFlatVector({0}); + data = makeRowVector({pVector, nuVector}); + + VELOX_ASSERT_THROW( + evaluate>("remote_inverse_cdf(c0, c1)", data), + "inverse_chi_squared_cdf: degrees of freedom must be > 0"); + + // Test nu < 0 + pVector = makeFlatVector({0.95}); + nuVector = makeFlatVector({-2}); + data = makeRowVector({pVector, nuVector}); + + VELOX_ASSERT_THROW( + evaluate>("remote_inverse_cdf(c0, c1)", data), + 
"inverse_chi_squared_cdf: degrees of freedom must be > 0"); +} + +TEST_P(RemoteFunctionRestTest, functionNotAvailable) { + auto inputVector = makeFlatVector({-10, -20}); + VELOX_ASSERT_THROW( + evaluate>( + "remote_round(c0)", makeRowVector({inputVector})), + "Server responded with status 400. Body: 'Function 'remote_round' is not available.'"); +} + +VELOX_INSTANTIATE_TEST_SUITE_P( + RemoteFunctionRestTestFixture, + RemoteFunctionRestTest, + ::testing::Values( + velox::functions::remote::PageFormat::PRESTO_PAGE, + velox::functions::remote::PageFormat::SPARK_UNSAFE_ROW)); + +} // namespace +} // namespace facebook::presto::functions::rest::test + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::Init init{&argc, &argv, false}; + return RUN_ALL_TESTS(); +} diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/CMakeLists.txt new file mode 100644 index 000000000000..9cd7cd3bcc41 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/CMakeLists.txt @@ -0,0 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# REST function server sources that live under tests/server.
add_library(
  presto_functions_rest_server
  RemoteFunctionRestService.cpp
  RestFunctionRegistry.cpp)

target_link_libraries(
  presto_functions_rest_server velox_presto_serializer)
+ */ +#pragma once + +#include + +#include "velox/vector/VectorStream.h" + +namespace facebook::presto::functions::rest { + +class RemoteFunctionRestHandler { + public: + RemoteFunctionRestHandler() = default; + + virtual ~RemoteFunctionRestHandler() = default; + + virtual velox::RowTypePtr getInputTypes() const = 0; + virtual velox::TypePtr getOutputType() const = 0; + + folly::IOBuf handleRequest( + std::unique_ptr inputBuffer, + velox::VectorSerde* serde, + velox::memory::MemoryPool* pool, + std::string& errorMessage) { + auto inputTypes = getInputTypes(); + auto outputType = getOutputType(); + + auto inputVector = IOBufToRowVector(*inputBuffer, inputTypes, *pool, serde); + + VELOX_CHECK_EQ( + inputVector->childrenSize(), + inputTypes->children().size(), + "Mismatched number of columns for remote function handler."); + + const auto numRows = inputVector->size(); + auto resultVector = velox::BaseVector::create(outputType, numRows, pool); + + compute(inputVector, resultVector, errorMessage); + + if (!errorMessage.empty()) { + return folly::IOBuf(); + } + + // Wrap the result in a RowVector to send back. + auto outputRowVector = std::make_shared( + pool, + velox::ROW({outputType}), + velox::BufferPtr(), + numRows, + std::vector{resultVector}); + + auto payload = rowVectorToIOBuf( + outputRowVector, outputRowVector->size(), *pool, serde); + + return payload; + } + + protected: + // Core computation function to be implemented by subclasses. 
+ virtual void compute( + const velox::RowVectorPtr& inputVector, + const velox::VectorPtr& resultVector, + std::string& errorMessage) = 0; +}; + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.cpp b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.cpp new file mode 100644 index 000000000000..2c944687f1f3 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.cpp @@ -0,0 +1,314 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h" + +#include + +#include "presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h" +#include "presto_cpp/main/functions/remote/utils/ContentTypes.h" +#include "velox/serializers/PrestoSerializer.h" +#include "velox/serializers/UnsafeRowSerializer.h" + +using namespace facebook::velox; +namespace facebook::presto::functions::rest { + +RestSession::RestSession(boost::asio::ip::tcp::socket socket) + : socket_(std::move(socket)), + pool_(memory::memoryManager()->addLeafPool()) {} + +void RestSession::run() { + doRead(); +} + +void RestSession::doRead() { + auto self = shared_from_this(); + boost::beast::http::async_read( + socket_, + buffer_, + req_, + [self](boost::beast::error_code ec, size_t bytes_transferred) { + self->onRead(ec, bytes_transferred); + }); +} + +void RestSession::onRead( + boost::beast::error_code ec, + std::size_t bytes_transferred) { + boost::ignore_unused(bytes_transferred); + + if (ec == boost::beast::http::error::end_of_stream) { + return doClose(); + } + + if (ec) { + LOG(ERROR) << "Read error: " << ec.message(); + return; + } + + handleRequest(std::move(req_)); +} + +void RestSession::handleRequest( + const boost::beast::http::request& req) { + res_.version(req.version()); + res_.set(boost::beast::http::field::server, BOOST_BEAST_VERSION_STRING); + + if (!ensurePostMethod(req)) { + return; + } + + if (!ensureValidContentType(req)) { + return; + } + + if (!ensureValidAcceptHeader(req)) { + return; + } + + std::string functionName; + if (!extractFunctionName(req, functionName)) { + return; + } + + try { + auto& registry = RestFunctionRegistry::getInstance(); + auto handler = registry.getFunction(functionName); + if (!handler) { + sendResponse( + boost::beast::http::status::bad_request, + plainText_, + fmt::format("Function '{}' is not available.", functionName), + false); + return; + } + + std::unique_ptr serde; + if (contentType_ == 
remote::CONTENT_TYPE_PRESTO_PAGE) { + serde = std::make_unique(); + } else { + serde = std::make_unique(); + } + auto inputBuffer = folly::IOBuf::copyBuffer(req.body()); + + std::string errorMessage; + auto outputBuffer = handler->handleRequest( + std::move(inputBuffer), serde.get(), pool_.get(), errorMessage); + if (!errorMessage.empty()) { + sendResponse( + boost::beast::http::status::internal_server_error, + plainText_, + errorMessage, + true); + return; + } + sendSuccessResponse(std::move(outputBuffer)); + } catch (const std::exception& ex) { + handleException(ex); + } +} + +void RestSession::sendResponse( + boost::beast::http::status status, + const std::string& contentType, + const std::string& body, + bool closeConnection) { + res_.result(status); + res_.set(boost::beast::http::field::content_type, contentType); + res_.body() = body; + res_.prepare_payload(); + + auto self = shared_from_this(); + boost::beast::http::async_write( + socket_, + res_, + [self, closeConnection]( + boost::beast::error_code ec, std::size_t bytes_transferred) { + self->onWrite(closeConnection, ec, bytes_transferred); + }); +} + +bool RestSession::ensurePostMethod( + const boost::beast::http::request& req) { + if (req.method() == boost::beast::http::verb::post) { + return true; + } + auto msg = fmt::format( + "Only POST method is allowed. Method used: {}", + std::string(req.method_string())); + sendResponse( + boost::beast::http::status::method_not_allowed, plainText_, msg, true); + return false; +} + +bool RestSession::ensureValidContentType( + const boost::beast::http::request& req) { + auto contentType = req[boost::beast::http::field::content_type]; + if (!contentType.empty() && + (contentType == remote::CONTENT_TYPE_SPARK_UNSAFE_ROW || + contentType == remote::CONTENT_TYPE_PRESTO_PAGE)) { + contentType_ = contentType; + return true; + } + + auto msg = fmt::format( + "Unsupported Content-Type: '{}'. 
Expecting '{}' or '{}'.", + std::string(contentType), + remote::CONTENT_TYPE_PRESTO_PAGE, + remote::CONTENT_TYPE_SPARK_UNSAFE_ROW); + sendResponse( + boost::beast::http::status::unsupported_media_type, + plainText_, + msg, + true); + return false; +} + +bool RestSession::ensureValidAcceptHeader( + const boost::beast::http::request& req) { + auto acceptHeader = req[boost::beast::http::field::accept]; + if (!acceptHeader.empty() && + (acceptHeader == remote::CONTENT_TYPE_PRESTO_PAGE || + acceptHeader == remote::CONTENT_TYPE_SPARK_UNSAFE_ROW)) { + return true; + } + + auto msg = fmt::format( + "Unsupported Accept header: '{}'. Expecting '{}' or '{}'.", + std::string(acceptHeader), + remote::CONTENT_TYPE_PRESTO_PAGE, + remote::CONTENT_TYPE_SPARK_UNSAFE_ROW); + sendResponse( + boost::beast::http::status::not_acceptable, plainText_, msg, true); + return false; +} + +bool RestSession::extractFunctionName( + const boost::beast::http::request& + requestBody, + std::string& functionName) { + std::string path = std::string(requestBody.target()); + std::vector pathComponents; + folly::split('/', path, pathComponents); + + if (pathComponents.size() != 2) { + sendResponse( + boost::beast::http::status::bad_request, + plainText_, + "Invalid request path", + true); + return false; + } + + // pathComponents[1] contains the function name + functionName = pathComponents[1]; + return true; +} + +void RestSession::handleException(const std::exception& ex) { + sendResponse( + boost::beast::http::status::internal_server_error, + plainText_, + ex.what(), + true); +} + +void RestSession::sendSuccessResponse(folly::IOBuf&& payload) { + sendResponse( + boost::beast::http::status::ok, + contentType_, + payload.moveToFbString().toStdString(), + false); +} + +void RestSession::onWrite( + bool close, + boost::beast::error_code ec, + std::size_t bytes_transferred) { + boost::ignore_unused(bytes_transferred); + + if (ec) { + LOG(ERROR) << "Write error: " << ec.message(); + return; + } + + if 
(close) { + return doClose(); + } + + req_ = {}; + + doRead(); +} + +void RestSession::doClose() { + boost::beast::error_code ec; + socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_send, ec); +} + +RestListener::RestListener( + boost::asio::io_context& ioc, + boost::asio::ip::tcp::endpoint endpoint) + : ioc_(ioc), acceptor_(ioc) { + boost::beast::error_code ec; + + acceptor_.open(endpoint.protocol(), ec); + if (ec) { + LOG(ERROR) << "Open error: " << ec.message(); + return; + } + + acceptor_.set_option(boost::asio::socket_base::reuse_address(true), ec); + if (ec) { + LOG(ERROR) << "Set_option error: " << ec.message(); + return; + } + + acceptor_.bind(endpoint, ec); + if (ec) { + LOG(ERROR) << "Bind error: " << ec.message(); + return; + } + + acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec); + if (ec) { + LOG(ERROR) << "Listen error: " << ec.message(); + return; + } +} + +void RestListener::run() { + doAccept(); +} + +void RestListener::doAccept() { + acceptor_.async_accept( + [self = shared_from_this()]( + boost::beast::error_code ec, boost::asio::ip::tcp::socket socket) { + self->onAccept(ec, std::move(socket)); + }); +} + +void RestListener::onAccept( + boost::beast::error_code ec, + boost::asio::ip::tcp::socket socket) { + if (ec) { + LOG(ERROR) << "Accept error: " << ec.message(); + } else { + std::make_shared(std::move(socket))->run(); + } + doAccept(); +} + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h new file mode 100644 index 000000000000..e24d889dd64f --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h @@ -0,0 +1,125 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" +#include "presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h" + +namespace facebook::presto::functions::rest { + +/// @brief Manages an individual HTTP session. +/// Handles reading HTTP requests, processing them, and sending responses. +/// Inspired by the reference implementation described in: +/// https://medium.com/@AlexanderObregon/building-restful-apis-with-c-4c8ac63fe8a7 +class RestSession : public std::enable_shared_from_this { + public: + RestSession(boost::asio::ip::tcp::socket socket); + + /// Starts the session by initiating a read operation. + void run(); + + private: + // Initiates an asynchronous read operation. + void doRead(); + + // Called when a read operation completes. + void onRead(boost::beast::error_code ec, std::size_t bytes_transferred); + + // Processes the HTTP request and prepares a response. + void handleRequest( + const boost::beast::http::request& req); + + // Called when a write operation completes. + void onWrite( + bool close, + boost::beast::error_code ec, + std::size_t bytes_transferred); + + // Closes the socket connection. 
+ void doClose(); + + // Helper to handle exceptions: logs error and sends HTTP 500 + void handleException(const std::exception& ex); + + // Helper to extract function name from /{functionName} + bool extractFunctionName( + const boost::beast::http::request& + requestBody, + std::string& functionName); + + // Sends a response with status, contentType, and body. Then calls + // async_write. + void sendResponse( + boost::beast::http::status status, + const std::string& contentType, + const std::string& body, + bool closeConnection); + + // Helper to ensure the HTTP method is POST, else sends an error response. + bool ensurePostMethod( + const boost::beast::http::request& req); + + // Helper to ensure Content-Type is either "application/X-presto-pages" or + // "application/X-spark-unsafe-row". Otherwise sends an error response. + bool ensureValidContentType( + const boost::beast::http::request& req); + + // Helper to ensure Accept is either "application/X-presto-pages" or + // "application/X-spark-unsafe-row". Otherwise sends an error response. + bool ensureValidAcceptHeader( + const boost::beast::http::request& req); + + // Sends a success response with the given payload. + void sendSuccessResponse(folly::IOBuf&& payload); + + boost::asio::ip::tcp::socket socket_; + boost::beast::flat_buffer buffer_; + boost::beast::http::request req_; + boost::beast::http::response res_; + std::shared_ptr pool_; + std::string contentType_; + static constexpr std::string plainText_ = "text/plain"; +}; + +/// @brief Listens for incoming TCP connections and creates sessions. +/// Sets up a TCP acceptor to listen for client connections, +/// creating a new session for each accepted connection. +class RestListener : public std::enable_shared_from_this { + public: + RestListener( + boost::asio::io_context& ioc, + boost::asio::ip::tcp::endpoint endpoint); + + /// Starts accepting incoming connections. + void run(); + + private: + // Initiates an asynchronous accept operation. 
+ void doAccept(); + + // Called when an accept operation completes. + void onAccept( + boost::beast::error_code ec, + boost::asio::ip::tcp::socket socket); + + boost::asio::io_context& ioc_; + boost::asio::ip::tcp::acceptor acceptor_; +}; + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.cpp b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.cpp new file mode 100644 index 000000000000..4cd56c209cd8 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.cpp @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h" + +#include + +namespace facebook::presto::functions::rest { + +RestFunctionRegistry& RestFunctionRegistry::getInstance() { + static RestFunctionRegistry instance; + return instance; +} + +bool RestFunctionRegistry::registerFunction( + const std::string& functionName, + std::shared_ptr handler) { + std::lock_guard lock(mutex_); + + auto it = functionHandlers_.find(functionName); + bool replacing = (it != functionHandlers_.end()); + + if (replacing) { + LOG(WARNING) << "Function handler for '" << functionName + << "' is being replaced."; + } + + functionHandlers_[functionName] = std::move(handler); + return replacing; +} + +bool RestFunctionRegistry::unregisterFunction(const std::string& functionName) { + std::lock_guard lock(mutex_); + return functionHandlers_.erase(functionName) > 0; +} + +std::shared_ptr RestFunctionRegistry::getFunction( + const std::string& functionName) const { + std::lock_guard lock(mutex_); + + auto it = functionHandlers_.find(functionName); + if (it != functionHandlers_.end()) { + return it->second; + } + return nullptr; +} + +bool RestFunctionRegistry::hasFunction(const std::string& functionName) const { + std::lock_guard lock(mutex_); + return functionHandlers_.find(functionName) != functionHandlers_.end(); +} + +void RestFunctionRegistry::clear() { + std::lock_guard lock(mutex_); + functionHandlers_.clear(); +} + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h new file mode 100644 index 000000000000..4284e3c08670 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" + +namespace facebook::presto::functions::rest { + +/// @brief Registry for remote function REST handlers. +/// Provides a centralized location for registering and looking up function +/// handlers by name. This follows the same pattern as WindowFunctionRegistry +/// in Velox. +class RestFunctionRegistry { + public: + /// Returns the singleton instance of the registry. + static RestFunctionRegistry& getInstance(); + + /// Registers a function handler for a given function name. + /// If a handler with the same name already exists, it will be replaced. + /// @param functionName The name of the function to register + /// @param handler The handler implementation for the function + /// @return true if a handler was replaced, false if this is a new + /// registration + bool registerFunction( + const std::string& functionName, + std::shared_ptr handler); + + /// Unregisters a function handler by name. + /// @param functionName The name of the function to unregister + /// @return true if a handler was found and removed, false otherwise + bool unregisterFunction(const std::string& functionName); + + /// Looks up a function handler by name. + /// @param functionName The name of the function to look up + /// @return The handler if found, nullptr otherwise + std::shared_ptr getFunction( + const std::string& functionName) const; + + /// Checks if a function is registered. 
+ /// @param functionName The name of the function to check + /// @return true if the function is registered, false otherwise + bool hasFunction(const std::string& functionName) const; + + /// Clears all registered functions. + /// Useful for testing scenarios. + void clear(); + + private: + RestFunctionRegistry() = default; + ~RestFunctionRegistry() = default; + + // Prevent copying and assignment + RestFunctionRegistry(const RestFunctionRegistry&) = delete; + RestFunctionRegistry& operator=(const RestFunctionRegistry&) = delete; + + mutable std::mutex mutex_; + std::unordered_map> + functionHandlers_; +}; + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteDoubleDivHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteDoubleDivHandler.h new file mode 100644 index 000000000000..b55c4641ab46 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteDoubleDivHandler.h @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" + +namespace facebook::presto::functions::rest { + +class RemoteDoubleDivHandler : public RemoteFunctionRestHandler { + public: + RemoteDoubleDivHandler() = default; + + velox::RowTypePtr getInputTypes() const override { + return velox::ROW({"c0", "c1"}, {velox::DOUBLE(), velox::DOUBLE()}); + } + + velox::TypePtr getOutputType() const override { + return velox::DOUBLE(); + } + + protected: + void compute( + const velox::RowVectorPtr& inputVector, + const velox::VectorPtr& resultVector, + std::string& /*errorMessage*/) override { + auto numerator = inputVector->childAt(0)->asFlatVector(); + auto denominator = inputVector->childAt(1)->asFlatVector(); + auto outFlat = resultVector->asFlatVector(); + + const auto numRows = inputVector->size(); + for (velox::vector_size_t i = 0; i < numRows; ++i) { + // If either input is null, output is null. + if (numerator->isNullAt(i) || denominator->isNullAt(i)) { + outFlat->setNull(i, true); + } else { + double numVal = numerator->valueAt(i); + double denVal = denominator->valueAt(i); + // If denominator is zero, produce a null. + if (denVal == 0.0) { + outFlat->setNull(i, true); + } else { + outFlat->set(i, numVal / denVal); + } + } + } + } +}; + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteFibonacciHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteFibonacciHandler.h new file mode 100644 index 000000000000..8de2bd969404 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteFibonacciHandler.h @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" + +namespace facebook::presto::functions::rest { + +class RemoteFibonacciHandler : public RemoteFunctionRestHandler { + public: + RemoteFibonacciHandler() = default; + + velox::RowTypePtr getInputTypes() const override { + return velox::ROW({"c0"}, {velox::BIGINT()}); + } + + velox::TypePtr getOutputType() const override { + return velox::BIGINT(); + } + + protected: + void compute( + const velox::RowVectorPtr& inputVector, + const velox::VectorPtr& resultVector, + std::string& /*errorMessage*/) override { + auto numFlat = inputVector->childAt(0)->asFlatVector(); + auto outFlat = resultVector->asFlatVector(); + + const auto numRows = inputVector->size(); + for (velox::vector_size_t i = 0; i < numRows; ++i) { + if (numFlat->isNullAt(i)) { + outFlat->setNull(i, true); + } else { + int64_t num = numFlat->valueAt(i); + outFlat->set(i, boost::math::fibonacci(num)); + } + } + } +}; + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteInverseCdfHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteInverseCdfHandler.h new file mode 100644 index 000000000000..bd65d34de129 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteInverseCdfHandler.h @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use 
this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" + +namespace facebook::presto::functions::rest { +namespace { +inline double inverse_chi_squared_cdf(double p, double nu) { + if (p <= 0.0 || p >= 1.0) { + throw std::domain_error("inverse_chi_squared_cdf: p must be in (0,1)"); + } + if (nu <= 0.0) { + throw std::domain_error( + "inverse_chi_squared_cdf: degrees of freedom must be > 0"); + } + + const boost::math::chi_squared_distribution chi2(nu); + double result = boost::math::quantile(chi2, p); + return std::round(result * 100.0) / 100.0; +} +} // namespace + +class RemoteInverseCdfHandler : public RemoteFunctionRestHandler { + public: + RemoteInverseCdfHandler() = default; + + velox::RowTypePtr getInputTypes() const override { + return velox::ROW({"c0", "c1"}, {velox::DOUBLE(), velox::DOUBLE()}); + } + + velox::TypePtr getOutputType() const override { + return velox::DOUBLE(); + } + + protected: + void compute( + const velox::RowVectorPtr& inputVector, + const velox::VectorPtr& resultVector, + std::string& errorMessage) override { + auto p = inputVector->childAt(0)->asFlatVector(); + auto nu = inputVector->childAt(1)->asFlatVector(); + auto outFlat = resultVector->asFlatVector(); + + const auto numRows = inputVector->size(); + for (velox::vector_size_t i = 0; i < numRows; ++i) { + // If either input is null, output is null. 
+ if (p->isNullAt(i) || nu->isNullAt(i)) { + outFlat->setNull(i, true); + } else { + try { + double pVal = p->valueAt(i); + double nuVal = nu->valueAt(i); + outFlat->set(i, inverse_chi_squared_cdf(pVal, nuVal)); + } catch (const std::domain_error& ex) { + errorMessage = ex.what(); + } + } + } + } +}; + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteRemoveCharHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteRemoveCharHandler.h new file mode 100644 index 000000000000..b79818925812 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteRemoveCharHandler.h @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" + +namespace facebook::presto::functions::rest { + +class RemoteRemoveCharHandler final : public RemoteFunctionRestHandler { + public: + RemoteRemoveCharHandler() = default; + + velox::RowTypePtr getInputTypes() const override { + return velox::ROW({"c0", "c1"}, {velox::VARCHAR(), velox::VARCHAR()}); + } + + velox::TypePtr getOutputType() const override { + return velox::VARCHAR(); + } + + protected: + void compute( + const velox::RowVectorPtr& inputVector, + const velox::VectorPtr& resultVector, + std::string& /*errorMessage*/) override { + auto inputFlat = inputVector->childAt(0)->asFlatVector(); + auto removeFlat = + inputVector->childAt(1)->asFlatVector(); + auto outFlat = resultVector->asFlatVector(); + + const auto numRows = inputVector->size(); + + for (velox::vector_size_t i = 0; i < numRows; ++i) { + if (inputFlat->isNullAt(i) || removeFlat->isNullAt(i)) { + outFlat->setNull(i, true); + continue; + } + std::string src( + inputFlat->valueAt(i).data(), inputFlat->valueAt(i).size()); + const auto removeView = removeFlat->valueAt(i); + if (removeView.empty()) { + outFlat->set(i, velox::StringView(src)); + continue; + } + const char ch = removeView.data()[0]; + src.erase(std::remove(src.begin(), src.end(), ch), src.end()); + outFlat->set(i, velox::StringView(src)); + } + } +}; + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteStrLenHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteStrLenHandler.h new file mode 100644 index 000000000000..980a50937048 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteStrLenHandler.h @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with 
the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h" + +namespace facebook::presto::functions::rest { +class RemoteStrLenHandler : public RemoteFunctionRestHandler { + public: + RemoteStrLenHandler() = default; + + velox::RowTypePtr getInputTypes() const override { + return velox::ROW({"c0"}, {velox::VARCHAR()}); + } + + velox::TypePtr getOutputType() const override { + return velox::INTEGER(); + } + + protected: + void compute( + const velox::RowVectorPtr& inputVector, + const velox::VectorPtr& resultVector, + std::string& errorMessage) override { + auto inputFlat = inputVector->childAt(0)->asFlatVector(); + auto outFlat = resultVector->asFlatVector(); + const auto numRows = inputVector->size(); + + for (velox::vector_size_t i = 0; i < numRows; ++i) { + if (inputFlat->isNullAt(i)) { + outFlat->setNull(i, true); + } else { + int32_t stringLen = inputFlat->valueAt(i).size(); + outFlat->set(i, stringLen); + } + } + } +}; + +} // namespace facebook::presto::functions::rest diff --git a/presto-native-execution/presto_cpp/main/functions/remote/utils/ContentTypes.h b/presto-native-execution/presto_cpp/main/functions/remote/utils/ContentTypes.h new file mode 100644 index 000000000000..4200ea74fb67 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/remote/utils/ContentTypes.h @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace facebook::presto::functions::remote { +inline constexpr const char* CONTENT_TYPE_SPARK_UNSAFE_ROW = + "application/X-spark-unsafe-row"; +inline constexpr const char* CONTENT_TYPE_PRESTO_PAGE = + "application/X-presto-pages"; +} // namespace facebook::presto::functions::remote diff --git a/presto-native-execution/presto_cpp/main/http/HttpConstants.h b/presto-native-execution/presto_cpp/main/http/HttpConstants.h index b6e3c24909a9..bc91594efce7 100644 --- a/presto-native-execution/presto_cpp/main/http/HttpConstants.h +++ b/presto-native-execution/presto_cpp/main/http/HttpConstants.h @@ -17,6 +17,7 @@ namespace facebook::presto::http { const uint16_t kHttpOk = 200; const uint16_t kHttpAccepted = 202; const uint16_t kHttpNoContent = 204; +const uint16_t kHttpMultipleChoices = 300; const uint16_t kHttpBadRequest = 400; const uint16_t kHttpUnauthorized = 401; const uint16_t kHttpNotFound = 404; diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt index 8f11b6308353..b58b1ac7a7b8 100644 --- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt @@ -16,6 +16,10 @@ target_link_libraries(presto_type_converter velox_presto_type_parser) add_library(presto_velox_expr_conversion OBJECT PrestoToVeloxExpr.cpp) target_link_libraries(presto_velox_expr_conversion velox_presto_types velox_vector velox_exception) +if(PRESTO_ENABLE_REMOTE_FUNCTIONS) + 
target_link_libraries(presto_velox_expr_conversion presto_to_velox_remote_functions) +endif() + add_library(presto_types PrestoToVeloxQueryPlan.cpp VeloxPlanValidator.cpp PrestoToVeloxSplit.cpp) target_link_libraries( presto_types diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index b83f12ca0e36..2aecd724cd77 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -14,7 +14,7 @@ #include "presto_cpp/main/types/PrestoToVeloxExpr.h" #include -#include +#include #include "presto_cpp/main/common/Configs.h" #include "presto_cpp/main/common/Utils.h" #include "presto_cpp/presto_protocol/Base64Util.h" @@ -23,6 +23,9 @@ #include "velox/vector/ComplexVector.h" #include "velox/vector/ConstantVector.h" #include "velox/vector/FlatVector.h" +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS +#include "presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h" +#endif using namespace facebook::velox::core; using facebook::velox::TypeKind; @@ -516,6 +519,19 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr( return std::make_shared( returnType, args, getFunctionName(sqlFunctionHandle->functionId)); } +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS + else if ( + auto restFunctionHandle = + std::dynamic_pointer_cast( + pexpr.functionHandle)) { + auto args = toVeloxExpr(pexpr.arguments); + auto returnType = typeParser_->parse(pexpr.returnType); + + functions::remote::rest::registerRestRemoteFunction(*restFunctionHandle); + return std::make_shared( + returnType, args, getFunctionName(restFunctionHandle->functionId)); + } +#endif VELOX_FAIL("Unsupported function handle: {}", pexpr.functionHandle->_type); } diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h index f63a84ec35ad..bbd0ae54ee81 100644 --- 
a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h @@ -13,7 +13,6 @@ */ #pragma once -#include #include "presto_cpp/main/types/TypeParser.h" #include "presto_cpp/presto_protocol/core/presto_protocol_core.h" #include "velox/core/Expressions.h" diff --git a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt index 0e1d27d1b18a..25ca93ca91be 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt @@ -57,33 +57,22 @@ target_link_libraries( GTest::gtest GTest::gtest_main presto_connectors - $ - $ - $ + presto_protocol + presto_type_converter + presto_types presto_operators presto_mutable_configs presto_type_test_utils presto_session_properties velox_core - velox_dwio_common_exception - velox_encode velox_exec velox_exec_test_lib velox_functions_prestosql - velox_functions_lib + velox_presto_type_parser velox_hive_connector velox_tpch_connector - velox_hive_partition_function - velox_presto_serializer - velox_presto_type_parser - velox_serialization velox_type Boost::filesystem - ${RE2} - ${FOLLY_WITH_DEPENDENCIES} - ${GLOG} - ${GFLAGS_LIBRARIES} - pthread ) set_property(TARGET presto_expressions_test PROPERTY JOB_POOL_LINK presto_link_job_pool) @@ -133,3 +122,25 @@ target_link_libraries( GTest::gtest GTest::gtest_main ) + +# RestFunctionHandleTest.cpp only contains test cases related to +# RestFunctionHandle, therefore it is only enabled when remote functions are +# enabled.
+if(PRESTO_ENABLE_REMOTE_FUNCTIONS) + add_executable(presto_to_rest_function_test RestFunctionHandleTest.cpp) + + add_test( + NAME presto_to_rest_function_test + COMMAND presto_to_rest_function_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + ) + + target_link_libraries( + presto_to_rest_function_test + presto_velox_expr_conversion + presto_protocol + presto_types + GTest::gtest + GTest::gtest_main + ) +endif() diff --git a/presto-native-execution/presto_cpp/main/types/tests/RestFunctionHandleTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/RestFunctionHandleTest.cpp new file mode 100644 index 000000000000..f02bd9ba2c39 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/RestFunctionHandleTest.cpp @@ -0,0 +1,451 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "presto_cpp/main/common/Configs.h" +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "velox/functions/remote/client/Remote.h" +#include "velox/functions/remote/server/RemoteFunctionService.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +class RestFunctionHandleTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void setupCallExpression() { + memoryPool_ = memory::MemoryManager::getInstance()->addLeafPool(); + converter_ = + std::make_unique(memoryPool_.get(), &typeParser_); + + expectedMetadata_.serdeFormat = functions::remote::PageFormat::PRESTO_PAGE; + + testExpr_.returnType = "bigint"; + testExpr_.displayName = "testFunction"; + auto cexpr = std::make_shared(); + cexpr->type = "bigint"; + cexpr->valueBlock.data = "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA"; + testExpr_.arguments.push_back(cexpr); + + auto cexpr2 = std::make_shared(); + cexpr2->type = "bigint"; + cexpr2->valueBlock.data = "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA"; + testExpr_.arguments.push_back(cexpr2); + } + + void SetUp() override { + auto restConfig = restSystemConfig(); + auto systemConfig = SystemConfig::instance(); + systemConfig->initialize(std::move(restConfig)); + setupCallExpression(); + } + + static std::unique_ptr restSystemConfig( + const std::unordered_map& configOverride = {}) { + std::unordered_map systemConfig{ + {std::string(SystemConfig::kRemoteFunctionServerSerde), + std::string("presto_page")}, + {std::string(SystemConfig::kRemoteFunctionServerRestURL), + std::string("http://localhost:8080")}}; + + for (const auto& [configName, configValue] : configOverride) { + systemConfig[configName] = configValue; + } + return std::make_unique(std::move(systemConfig), true); + } + + std::shared_ptr functionHandle_; + protocol::CallExpression testExpr_; + functions::RemoteVectorFunctionMetadata expectedMetadata_; + 
std::shared_ptr memoryPool_; + TypeParser typeParser_; + std::unique_ptr converter_; +}; + +TEST_F(RestFunctionHandleTest, parseRestFunctionHandle) { + try { + const std::string str = R"( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.testFunction;BIGINT;BIGINT", + "version": "v1", + "signature": { + "name": "testFunction", + "kind": "SCALAR", + "returnType": "bigint", + "argumentTypes": ["bigint", "bigint"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + } + } + )"; + const json j = json::parse(str); + const std::shared_ptr restFunctionHandle = j; + testExpr_.functionHandle = restFunctionHandle; + + auto expr = converter_->toVeloxExpr(testExpr_); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_NE(callExpr, nullptr); + EXPECT_EQ(callExpr->name(), "remote.testSchema.testFunction"); + + EXPECT_EQ(callExpr->inputs().size(), 2); + auto arg0 = std::dynamic_pointer_cast( + callExpr->inputs()[0]); + auto arg1 = std::dynamic_pointer_cast( + callExpr->inputs()[1]); + ASSERT_NE(arg0, nullptr); + ASSERT_NE(arg1, nullptr); + EXPECT_EQ(arg0->type()->kind(), TypeKind::BIGINT); + EXPECT_EQ(arg1->type()->kind(), TypeKind::BIGINT); + + } catch (const std::exception& e) { + FAIL() << "Exception: " << e.what(); + } +} + +TEST_F(RestFunctionHandleTest, parseRestFunctionHandleWithDecimalType) { + try { + const std::string str = R"JSON( +{ + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.decimalFunction;decimal(10,2);decimal(10,2)", + "version": "v1", + "signature": { + "name": "decimalFunction", + "kind": "SCALAR", + "returnType": "decimal(10,2)", + "argumentTypes": ["decimal(10,2)", "decimal(10,2)"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + } +} +)JSON"; + + const json j = json::parse(str); + const std::shared_ptr restFunctionHandle = j; + + protocol::CallExpression callExpr; + callExpr.returnType = "decimal(10,2)"; + 
callExpr.displayName = "decimalFunction"; + callExpr.functionHandle = restFunctionHandle; + + auto cexpr = std::make_shared(); + cexpr->type = "decimal(10,2)"; + cexpr->valueBlock.data = "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA"; + callExpr.arguments.push_back(cexpr); + + auto cexpr2 = std::make_shared(); + cexpr2->type = "decimal(10,2)"; + cexpr2->valueBlock.data = "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA"; + callExpr.arguments.push_back(cexpr2); + + auto expr = converter_->toVeloxExpr(callExpr); + auto typedCallExpr = + std::dynamic_pointer_cast(expr); + ASSERT_NE(typedCallExpr, nullptr); + EXPECT_EQ(typedCallExpr->name(), "remote.testSchema.decimalFunction"); + EXPECT_EQ(typedCallExpr->type()->kind(), TypeKind::BIGINT); + + EXPECT_EQ(typedCallExpr->inputs().size(), 2); + auto arg0 = std::dynamic_pointer_cast( + typedCallExpr->inputs()[0]); + auto arg1 = std::dynamic_pointer_cast( + typedCallExpr->inputs()[1]); + ASSERT_NE(arg0, nullptr); + ASSERT_NE(arg1, nullptr); + EXPECT_EQ(arg0->type()->kind(), TypeKind::BIGINT); + EXPECT_EQ(arg1->type()->kind(), TypeKind::BIGINT); + + } catch (const std::exception& e) { + FAIL() << "Exception: " << e.what(); + } +} + +TEST_F(RestFunctionHandleTest, parseRestFunctionHandleWithArrayType) { + try { + const std::string str = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.arrayFunction;array(bigint);array(varchar)", + "version": "v1", + "signature": { + "name": "arrayFunction", + "kind": "SCALAR", + "returnType": "array(bigint)", + "argumentTypes": ["array(bigint)", "array(varchar)"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + } + } + )JSON"; + const json j = json::parse(str); + const std::shared_ptr restFunctionHandle = j; + + // Verify the signature parsing + ASSERT_NE(restFunctionHandle, nullptr); + EXPECT_EQ( + restFunctionHandle->functionId, + "remote.testSchema.arrayFunction;array(bigint);array(varchar)"); + 
EXPECT_EQ(restFunctionHandle->signature.name, "arrayFunction"); + EXPECT_EQ(restFunctionHandle->signature.returnType, "array(bigint)"); + EXPECT_EQ(restFunctionHandle->signature.argumentTypes.size(), 2); + EXPECT_EQ(restFunctionHandle->signature.argumentTypes[0], "array(bigint)"); + EXPECT_EQ(restFunctionHandle->signature.argumentTypes[1], "array(varchar)"); + + // Verify type parsing + auto returnType = + typeParser_.parse(restFunctionHandle->signature.returnType); + EXPECT_EQ(returnType->kind(), TypeKind::ARRAY); + auto returnArrayType = + std::dynamic_pointer_cast(returnType); + ASSERT_NE(returnArrayType, nullptr); + EXPECT_EQ(returnArrayType->elementType()->kind(), TypeKind::BIGINT); + + auto argType0 = + typeParser_.parse(restFunctionHandle->signature.argumentTypes[0]); + EXPECT_EQ(argType0->kind(), TypeKind::ARRAY); + auto argArrayType0 = std::dynamic_pointer_cast(argType0); + ASSERT_NE(argArrayType0, nullptr); + EXPECT_EQ(argArrayType0->elementType()->kind(), TypeKind::BIGINT); + + auto argType1 = + typeParser_.parse(restFunctionHandle->signature.argumentTypes[1]); + EXPECT_EQ(argType1->kind(), TypeKind::ARRAY); + auto argArrayType1 = std::dynamic_pointer_cast(argType1); + ASSERT_NE(argArrayType1, nullptr); + EXPECT_EQ(argArrayType1->elementType()->kind(), TypeKind::VARCHAR); + + } catch (const std::exception& e) { + FAIL() << "Exception: " << e.what(); + } +} + +TEST_F(RestFunctionHandleTest, parseRestFunctionHandleWithMapType) { + try { + const std::string str = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.mapFunction;map(varchar,bigint);map(bigint,double)", + "version": "v1", + "signature": { + "name": "mapFunction", + "kind": "SCALAR", + "returnType": "map(varchar,bigint)", + "argumentTypes": ["map(varchar,bigint)", "map(bigint,double)"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + } + } + )JSON"; + const json j = json::parse(str); + const std::shared_ptr 
restFunctionHandle = j; + + // Verify the signature parsing + ASSERT_NE(restFunctionHandle, nullptr); + EXPECT_EQ( + restFunctionHandle->functionId, + "remote.testSchema.mapFunction;map(varchar,bigint);map(bigint,double)"); + EXPECT_EQ(restFunctionHandle->signature.name, "mapFunction"); + EXPECT_EQ(restFunctionHandle->signature.returnType, "map(varchar,bigint)"); + EXPECT_EQ(restFunctionHandle->signature.argumentTypes.size(), 2); + EXPECT_EQ( + restFunctionHandle->signature.argumentTypes[0], "map(varchar,bigint)"); + EXPECT_EQ( + restFunctionHandle->signature.argumentTypes[1], "map(bigint,double)"); + + // Verify type parsing + auto returnType = + typeParser_.parse(restFunctionHandle->signature.returnType); + EXPECT_EQ(returnType->kind(), TypeKind::MAP); + auto returnMapType = std::dynamic_pointer_cast(returnType); + ASSERT_NE(returnMapType, nullptr); + EXPECT_EQ(returnMapType->keyType()->kind(), TypeKind::VARCHAR); + EXPECT_EQ(returnMapType->valueType()->kind(), TypeKind::BIGINT); + + auto argType0 = + typeParser_.parse(restFunctionHandle->signature.argumentTypes[0]); + EXPECT_EQ(argType0->kind(), TypeKind::MAP); + auto argMapType0 = std::dynamic_pointer_cast(argType0); + ASSERT_NE(argMapType0, nullptr); + EXPECT_EQ(argMapType0->keyType()->kind(), TypeKind::VARCHAR); + EXPECT_EQ(argMapType0->valueType()->kind(), TypeKind::BIGINT); + + auto argType1 = + typeParser_.parse(restFunctionHandle->signature.argumentTypes[1]); + EXPECT_EQ(argType1->kind(), TypeKind::MAP); + auto argMapType1 = std::dynamic_pointer_cast(argType1); + ASSERT_NE(argMapType1, nullptr); + EXPECT_EQ(argMapType1->keyType()->kind(), TypeKind::BIGINT); + EXPECT_EQ(argMapType1->valueType()->kind(), TypeKind::DOUBLE); + + } catch (const std::exception& e) { + FAIL() << "Exception: " << e.what(); + } +} + +TEST_F(RestFunctionHandleTest, parseRestFunctionHandleWithRowType) { + try { + const std::string str = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": 
"remote.testSchema.rowFunction;row(bigint,varchar);row(double,boolean)", + "version": "v1", + "signature": { + "name": "rowFunction", + "kind": "SCALAR", + "returnType": "row(bigint,varchar)", + "argumentTypes": ["row(bigint,varchar)", "row(double,boolean)"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + } + } + )JSON"; + const json j = json::parse(str); + const std::shared_ptr restFunctionHandle = j; + + // Verify the signature parsing + ASSERT_NE(restFunctionHandle, nullptr); + EXPECT_EQ( + restFunctionHandle->functionId, + "remote.testSchema.rowFunction;row(bigint,varchar);row(double,boolean)"); + EXPECT_EQ(restFunctionHandle->signature.name, "rowFunction"); + EXPECT_EQ(restFunctionHandle->signature.returnType, "row(bigint,varchar)"); + EXPECT_EQ(restFunctionHandle->signature.argumentTypes.size(), 2); + EXPECT_EQ( + restFunctionHandle->signature.argumentTypes[0], "row(bigint,varchar)"); + EXPECT_EQ( + restFunctionHandle->signature.argumentTypes[1], "row(double,boolean)"); + + // Verify type parsing + auto returnType = + typeParser_.parse(restFunctionHandle->signature.returnType); + EXPECT_EQ(returnType->kind(), TypeKind::ROW); + auto returnRowType = std::dynamic_pointer_cast(returnType); + ASSERT_NE(returnRowType, nullptr); + EXPECT_EQ(returnRowType->size(), 2); + EXPECT_EQ(returnRowType->childAt(0)->kind(), TypeKind::BIGINT); + EXPECT_EQ(returnRowType->childAt(1)->kind(), TypeKind::VARCHAR); + + auto argType0 = + typeParser_.parse(restFunctionHandle->signature.argumentTypes[0]); + EXPECT_EQ(argType0->kind(), TypeKind::ROW); + auto argRowType0 = std::dynamic_pointer_cast(argType0); + ASSERT_NE(argRowType0, nullptr); + EXPECT_EQ(argRowType0->size(), 2); + EXPECT_EQ(argRowType0->childAt(0)->kind(), TypeKind::BIGINT); + EXPECT_EQ(argRowType0->childAt(1)->kind(), TypeKind::VARCHAR); + + auto argType1 = + typeParser_.parse(restFunctionHandle->signature.argumentTypes[1]); + EXPECT_EQ(argType1->kind(), TypeKind::ROW); + 
auto argRowType1 = std::dynamic_pointer_cast(argType1); + ASSERT_NE(argRowType1, nullptr); + EXPECT_EQ(argRowType1->size(), 2); + EXPECT_EQ(argRowType1->childAt(0)->kind(), TypeKind::DOUBLE); + EXPECT_EQ(argRowType1->childAt(1)->kind(), TypeKind::BOOLEAN); + + } catch (const std::exception& e) { + FAIL() << "Exception: " << e.what(); + } +} + +TEST_F(RestFunctionHandleTest, parseRestFunctionHandleWithNestedComplexTypes) { + try { + const std::string str = R"JSON( + { + "@type": "RestFunctionHandle", + "functionId": "remote.testSchema.nestedFunction;array(map(varchar,bigint));row(array(decimal(10,2)),map(bigint,varchar))", + "version": "v1", + "signature": { + "name": "nestedFunction", + "kind": "SCALAR", + "returnType": "map(varchar,array(bigint))", + "argumentTypes": ["array(map(varchar,bigint))", "row(array(decimal(10,2)),map(bigint,varchar))"], + "typeVariableConstraints": [], + "longVariableConstraints": [], + "variableArity": false + } + } + )JSON"; + const json j = json::parse(str); + const std::shared_ptr restFunctionHandle = j; + + // Verify the signature parsing + ASSERT_NE(restFunctionHandle, nullptr); + EXPECT_EQ(restFunctionHandle->signature.name, "nestedFunction"); + EXPECT_EQ( + restFunctionHandle->signature.returnType, "map(varchar,array(bigint))"); + EXPECT_EQ(restFunctionHandle->signature.argumentTypes.size(), 2); + + // Verify return type: map(varchar,array(bigint)) + auto returnType = + typeParser_.parse(restFunctionHandle->signature.returnType); + EXPECT_EQ(returnType->kind(), TypeKind::MAP); + auto returnMapType = std::dynamic_pointer_cast(returnType); + ASSERT_NE(returnMapType, nullptr); + EXPECT_EQ(returnMapType->keyType()->kind(), TypeKind::VARCHAR); + EXPECT_EQ(returnMapType->valueType()->kind(), TypeKind::ARRAY); + auto valueArrayType = + std::dynamic_pointer_cast(returnMapType->valueType()); + ASSERT_NE(valueArrayType, nullptr); + EXPECT_EQ(valueArrayType->elementType()->kind(), TypeKind::BIGINT); + + // Verify arg0 type: 
array(map(varchar,bigint)) + auto argType0 = + typeParser_.parse(restFunctionHandle->signature.argumentTypes[0]); + EXPECT_EQ(argType0->kind(), TypeKind::ARRAY); + auto argArrayType = std::dynamic_pointer_cast(argType0); + ASSERT_NE(argArrayType, nullptr); + EXPECT_EQ(argArrayType->elementType()->kind(), TypeKind::MAP); + auto elementMapType = + std::dynamic_pointer_cast(argArrayType->elementType()); + ASSERT_NE(elementMapType, nullptr); + EXPECT_EQ(elementMapType->keyType()->kind(), TypeKind::VARCHAR); + EXPECT_EQ(elementMapType->valueType()->kind(), TypeKind::BIGINT); + + // Verify arg1 type: row(array(decimal(10,2)),map(bigint,varchar)) + auto argType1 = + typeParser_.parse(restFunctionHandle->signature.argumentTypes[1]); + EXPECT_EQ(argType1->kind(), TypeKind::ROW); + auto argRowType = std::dynamic_pointer_cast(argType1); + ASSERT_NE(argRowType, nullptr); + EXPECT_EQ(argRowType->size(), 2); + + // First child: array(decimal(10,2)) + EXPECT_EQ(argRowType->childAt(0)->kind(), TypeKind::ARRAY); + auto childArrayType = + std::dynamic_pointer_cast(argRowType->childAt(0)); + ASSERT_NE(childArrayType, nullptr); + EXPECT_EQ(childArrayType->elementType()->kind(), TypeKind::BIGINT); + + // Second child: map(bigint,varchar) + EXPECT_EQ(argRowType->childAt(1)->kind(), TypeKind::MAP); + auto childMapType = + std::dynamic_pointer_cast(argRowType->childAt(1)); + ASSERT_NE(childMapType, nullptr); + EXPECT_EQ(childMapType->keyType()->kind(), TypeKind::BIGINT); + EXPECT_EQ(childMapType->valueType()->kind(), TypeKind::VARCHAR); + + } catch (const std::exception& e) { + FAIL() << "Exception: " << e.what(); + } +} diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java index 99da6faace55..b07e6112c7b5 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java +++ 
b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java @@ -28,6 +28,7 @@ import com.facebook.presto.sql.planner.NodePartitioningManager; import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager; import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.testing.TestingAccessControlManager; import com.facebook.presto.transaction.TransactionManager; @@ -38,6 +39,7 @@ import java.io.IOException; import java.sql.Connection; +import java.sql.DriverManager; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; @@ -50,46 +52,57 @@ import java.util.logging.Logger; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; -import static java.sql.DriverManager.getConnection; public class ContainerQueryRunner implements QueryRunner { - private static final Network network = Network.newNetwork(); - private static final String PRESTO_COORDINATOR_IMAGE = System.getProperty("coordinatorImage", "presto-coordinator:latest"); - private static final String PRESTO_WORKER_IMAGE = System.getProperty("workerImage", "presto-worker:latest"); - private static final String CONTAINER_TIMEOUT = System.getProperty("containerTimeout", "120"); - private static final String CLUSTER_SHUTDOWN_TIMEOUT = System.getProperty("clusterShutDownTimeout", "10"); - private static final String BASE_DIR = System.getProperty("user.dir"); - private static final int DEFAULT_COORDINATOR_PORT = 8080; - private static final String TPCH_CATALOG = "tpch"; - private static final String TINY_SCHEMA = "tiny"; - private static final int DEFAULT_NUMBER_OF_WORKERS = 4; - private static final Logger logger = Logger.getLogger(ContainerQueryRunner.class.getName()); - private final GenericContainer coordinator; - private final List> workers = new ArrayList<>(); - private final int coordinatorPort; - 
private final String catalog; - private final String schema; - private final int numberOfWorkers; - private Connection connection; + protected static final Network network = Network.newNetwork(); + protected static final String PRESTO_COORDINATOR_IMAGE = System.getProperty("coordinatorImage", "presto-coordinator:latest"); + protected static final String PRESTO_WORKER_IMAGE = System.getProperty("workerImage", "presto-worker:latest"); + protected static final String CONTAINER_TIMEOUT = System.getProperty("containerTimeout", "120"); + protected static final String CLUSTER_SHUTDOWN_TIMEOUT = System.getProperty("clusterShutDownTimeout", "10"); + protected static final String BASE_DIR = System.getProperty("user.dir"); + protected static final int DEFAULT_COORDINATOR_PORT = 8080; + protected static final int DEFAULT_FUNCTION_SERVER_PORT = 1122; + protected static final String TPCH_CATALOG = "tpch"; + protected static final String TINY_SCHEMA = "tiny"; + protected static final int DEFAULT_NUMBER_OF_WORKERS = 4; + + protected static final Logger logger = Logger.getLogger(ContainerQueryRunner.class.getName()); + + protected final GenericContainer coordinator; + protected final List> workers = new ArrayList<>(); + protected final int coordinatorPort; + protected final String catalog; + protected final String schema; + protected GenericContainer functionServer; + protected int functionServerPort; + protected boolean enableFunctionServer; + protected Connection connection; public ContainerQueryRunner() throws InterruptedException, IOException { - this(DEFAULT_COORDINATOR_PORT, TPCH_CATALOG, TINY_SCHEMA, DEFAULT_NUMBER_OF_WORKERS); + this(DEFAULT_COORDINATOR_PORT, TPCH_CATALOG, TINY_SCHEMA, DEFAULT_NUMBER_OF_WORKERS, DEFAULT_FUNCTION_SERVER_PORT, false); } - public ContainerQueryRunner(int coordinatorPort, String catalog, String schema, int numberOfWorkers) + public ContainerQueryRunner(int coordinatorPort, String catalog, String schema, int numberOfWorkers, int 
functionServerPort, boolean enableFunctionServer) throws InterruptedException, IOException { this.coordinatorPort = coordinatorPort; this.catalog = catalog; this.schema = schema; - this.numberOfWorkers = numberOfWorkers; + this.functionServerPort = functionServerPort; + this.enableFunctionServer = enableFunctionServer; + + // Start function server first if enabled + if (enableFunctionServer) { + this.functionServer = createFunctionServer(); + this.functionServer.start(); + logger.info("Presto function server is deployed at http://" + functionServer.getHost() + ":" + functionServer.getMappedPort(functionServerPort)); + } - // The container details can be added as properties in VM options for testing in IntelliJ. - coordinator = createCoordinator(); + this.coordinator = createCoordinator(); for (int i = 0; i < numberOfWorkers; i++) { workers.add(createNativeWorker(7777 + i, "native-worker-" + i)); } @@ -107,10 +120,10 @@ public ContainerQueryRunner(int coordinatorPort, String catalog, String schema, coordinator.getMappedPort(coordinatorPort), catalog, schema, - "timeZoneId=UTC"); + enableFunctionServer ? "timeZoneId=UTC&sessionProperties=remote_functions_enabled:true" : "timeZoneId=UTC"); try { - connection = getConnection(url, "test", null); + this.connection = DriverManager.getConnection(url, "test", null); } catch (SQLException e) { throw new RuntimeException(e); @@ -118,12 +131,16 @@ public ContainerQueryRunner(int coordinatorPort, String catalog, String schema, // Delete the temporary files once the containers are started. 
ContainerQueryRunnerUtils.deleteDirectory(BASE_DIR + "/testcontainers/coordinator"); - for (int i = 0; i < numberOfWorkers; i++) { - ContainerQueryRunnerUtils.deleteDirectory(BASE_DIR + "/testcontainers/native-worker-" + i); + for (GenericContainer worker : workers) { + String alias = worker.getNetworkAliases().get(1); + ContainerQueryRunnerUtils.deleteDirectory(BASE_DIR + "/testcontainers/" + alias); + } + if (enableFunctionServer) { + ContainerQueryRunnerUtils.deleteDirectory(BASE_DIR + "/testcontainers/function-server"); } } - private GenericContainer createCoordinator() + protected GenericContainer createCoordinator() throws IOException { ContainerQueryRunnerUtils.createCoordinatorTpchProperties(); @@ -132,22 +149,25 @@ private GenericContainer createCoordinator() ContainerQueryRunnerUtils.createCoordinatorJvmConfig(); ContainerQueryRunnerUtils.createCoordinatorLogProperties(); ContainerQueryRunnerUtils.createCoordinatorNodeProperties(); - ContainerQueryRunnerUtils.createCoordinatorEntryPointScript(); + ContainerQueryRunnerUtils.createCoordinatorEntryPointScript(); // Never run function server in coordinator + if (enableFunctionServer) { + ContainerQueryRunnerUtils.createRestRemoteProperties(functionServerPort); + } return new GenericContainer<>(PRESTO_COORDINATOR_IMAGE) - .withExposedPorts(coordinatorPort) .withNetwork(network) .withNetworkAliases("presto-coordinator") .withCopyFileToContainer(MountableFile.forHostPath(BASE_DIR + "/testcontainers/coordinator/etc"), "/opt/presto-server/etc") .withCopyFileToContainer(MountableFile.forHostPath(BASE_DIR + "/testcontainers/coordinator/entrypoint.sh"), "/opt/entrypoint.sh") .waitingFor(Wait.forLogMessage(".*======== SERVER STARTED ========.*", 1)) - .withStartupTimeout(Duration.ofSeconds(Long.parseLong(CONTAINER_TIMEOUT))); + .withStartupTimeout(Duration.ofSeconds(Long.parseLong(CONTAINER_TIMEOUT))) + .withExposedPorts(coordinatorPort); } - private GenericContainer createNativeWorker(int port, String nodeId) + 
protected GenericContainer createNativeWorker(int port, String nodeId) throws IOException { - ContainerQueryRunnerUtils.createNativeWorkerConfigProperties(coordinatorPort, nodeId); + ContainerQueryRunnerUtils.createNativeWorkerConfigPropertiesWithFunctionServer(coordinatorPort, functionServerPort, nodeId); ContainerQueryRunnerUtils.createNativeWorkerTpchProperties(nodeId); ContainerQueryRunnerUtils.createNativeWorkerEntryPointScript(nodeId); ContainerQueryRunnerUtils.createNativeWorkerNodeProperties(nodeId); @@ -160,6 +180,23 @@ private GenericContainer createNativeWorker(int port, String nodeId) .waitingFor(Wait.forLogMessage(".*Announcement succeeded: HTTP 202.*", 1)); } + protected GenericContainer createFunctionServer() + throws IOException + { + ContainerQueryRunnerUtils.createFunctionServerConfigProperties(functionServerPort); + ContainerQueryRunnerUtils.createFunctionServerEntryPointScript(); + + // Reuse the coordinator image since it already contains the function server jar + return new GenericContainer<>(PRESTO_COORDINATOR_IMAGE) + .withNetwork(network) + .withNetworkAliases("presto-remote-function-server") + .withCopyFileToContainer(MountableFile.forHostPath(BASE_DIR + "/testcontainers/function-server/etc"), "/opt/function-server/etc") + .withCopyFileToContainer(MountableFile.forHostPath(BASE_DIR + "/testcontainers/function-server/entrypoint.sh"), "/opt/entrypoint.sh") + .waitingFor(Wait.forLogMessage(".*======== REMOTE FUNCTION SERVER STARTED at: .*", 1)) + .withStartupTimeout(Duration.ofSeconds(Long.parseLong(CONTAINER_TIMEOUT))) + .withExposedPorts(functionServerPort); + } + @Override public void close() { @@ -171,6 +208,9 @@ public void close() } coordinator.stop(); workers.forEach(GenericContainer::stop); + if (functionServer != null) { + functionServer.stop(); + } } @Override @@ -248,7 +288,27 @@ public MaterializedResult execute(String sql) @Override public MaterializedResult execute(Session session, String sql, List resultTypes) { - throw new 
UnsupportedOperationException(); + // Added logic similar to H2QueryRunner. + try { + Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(sql); + MaterializedResult rawResult = ContainerQueryRunnerUtils.toMaterializedResult(resultSet); + + // Coerce the raw result to the requested resultTypes + List coercedRows = new ArrayList<>(); + for (MaterializedRow row : rawResult.getMaterializedRows()) { + List coercedValues = new ArrayList<>(); + for (int i = 0; i < resultTypes.size(); i++) { + Object value = row.getField(i); + coercedValues.add(value); + } + coercedRows.add(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, coercedValues)); + } + return new MaterializedResult(coercedRows, resultTypes); + } + catch (SQLException e) { + throw new RuntimeException("Error executing query: " + sql, e); + } } @Override diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunnerUtils.java index 2044ae913a71..1ca1beedeb93 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunnerUtils.java @@ -82,7 +82,7 @@ public static void createNativeWorkerTpchProperties(String nodeId) createPropertiesFile("testcontainers/" + nodeId + "/etc/catalog/tpch.properties", properties); } - public static void createNativeWorkerConfigProperties(int coordinatorPort, String nodeId) + public static void createNativeWorkerConfigPropertiesWithFunctionServer(int coordinatorPort, int functionServerPort, String nodeId) throws IOException { Properties properties = new Properties(); @@ -90,7 +90,7 @@ public static void createNativeWorkerConfigProperties(int coordinatorPort, Strin properties.setProperty("http-server.http.port", "7777"); 
properties.setProperty("discovery.uri", "http://presto-coordinator:" + coordinatorPort); properties.setProperty("system-memory-gb", "2"); - properties.setProperty("native.sidecar", "false"); + properties.setProperty("remote-function-server.rest.url", "http://presto-remote-function-server:" + functionServerPort); createPropertiesFile("testcontainers/" + nodeId + "/etc/config.properties", properties); } @@ -104,6 +104,8 @@ public static void createCoordinatorConfigProperties(int port) properties.setProperty("http-server.http.port", Integer.toString(port)); properties.setProperty("discovery-server.enabled", "true"); properties.setProperty("discovery.uri", "http://presto-coordinator:" + port); + properties.setProperty("list-built-in-functions-only", "false"); + properties.setProperty("native-execution-enabled", "true"); // Get native worker system properties and add them to the coordinator properties Map nativeWorkerProperties = NativeQueryRunnerUtils.getNativeWorkerSystemProperties(); @@ -114,9 +116,37 @@ public static void createCoordinatorConfigProperties(int port) createPropertiesFile("testcontainers/coordinator/etc/config.properties", properties); } - public static void createCoordinatorJvmConfig() + public static void createRestRemoteProperties(int functionServerPort) throws IOException + { + Properties properties = new Properties(); + properties.setProperty("function-namespace-manager.name", "rest"); + properties.setProperty("supported-function-languages", "Java"); + properties.setProperty("function-implementation-type", "REST"); + properties.setProperty("rest-based-function-manager.rest.url", "http://presto-remote-function-server:" + functionServerPort); + + String directoryPath = "testcontainers/function-namespace"; + File directory = new File(directoryPath); + if (!directory.exists()) { + directory.mkdirs(); + } + createPropertiesFile("testcontainers/coordinator/etc/function-namespace/remote.properties", properties); + } + + public static void 
createFunctionServerConfigProperties(int functionServerPort) + throws IOException + { + Properties properties = new Properties(); + properties.setProperty("http-server.http.port", String.valueOf(functionServerPort)); + properties.setProperty("regex-library", "RE2J"); + properties.setProperty("parse-decimal-literals-as-double", "true"); + + createPropertiesFile("testcontainers/function-server/etc/config.properties", properties); + } + + public static void createCoordinatorJvmConfig() + throws IOException { String jvmConfig = "-server\n" + "-Xmx1G\n" + @@ -164,10 +194,27 @@ public static void createCoordinatorEntryPointScript() { String scriptContent = "#!/bin/sh\n" + "set -e\n" + - "$PRESTO_HOME/bin/launcher run\n"; + "trap 'kill -TERM $app 2>/dev/null' TERM\n" + + "$PRESTO_HOME/bin/launcher run &\n" + + "app=$!\n" + + "wait $app"; createScriptFile("testcontainers/coordinator/entrypoint.sh", scriptContent); } + public static void createFunctionServerEntryPointScript() + throws IOException + { + String scriptContent = "#!/bin/sh\n" + + "set -e\n" + + "trap 'kill -TERM $app 2>/dev/null' TERM\n" + + "java -Dconfig=/opt/function-server/etc/config.properties " + + "-jar /opt/presto-remote-function-server &\n" + + "app=$!\n" + + "wait $app\n"; + + createScriptFile("testcontainers/function-server/entrypoint.sh", scriptContent); + } + public static void createNativeWorkerEntryPointScript(String nodeId) throws IOException { diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoContainerRemoteFunction.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoContainerRemoteFunction.java new file mode 100644 index 000000000000..a0d3359ac242 --- /dev/null +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoContainerRemoteFunction.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance 
with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.nativeworker; + +import com.facebook.presto.tests.AbstractTestQueryFramework; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + +/** + * These tests call remote functions served by the Presto Function Server + * (implementation: {@link com.facebook.presto.server.FunctionServer}). + */ + +public class TestPrestoContainerRemoteFunction + extends AbstractTestQueryFramework +{ + @Override + protected ContainerQueryRunner createQueryRunner() + throws Exception + { + return new ContainerQueryRunner( + ContainerQueryRunner.DEFAULT_COORDINATOR_PORT, + ContainerQueryRunner.TPCH_CATALOG, + ContainerQueryRunner.TINY_SCHEMA, + ContainerQueryRunner.DEFAULT_NUMBER_OF_WORKERS, + ContainerQueryRunner.DEFAULT_FUNCTION_SERVER_PORT, + true); + } + + @Test + public void testRemoteBasicTests() + { + assertEquals( + computeActual("select remote.default.abs(-10)") + .getMaterializedRows().get(0).getField(0).toString(), + "10"); + assertEquals( + computeActual("select remote.default.abs(-1230)") + .getMaterializedRows().get(0).getField(0).toString(), + "1230"); + assertEquals( + computeActual("select remote.default.day(interval '2' day)") + .getMaterializedRows().get(0).getField(0).toString(), + "2"); + assertEquals( + computeActual("select remote.default.length(CAST('AB' AS VARBINARY))") + .getMaterializedRows().get(0).getField(0).toString(), + "2"); + assertEquals( + computeActual("select remote.default.floor(100000.99)") + .getMaterializedRows().get(0).getField(0).toString(), + 
"100000.0"); + assertEquals( + computeActual("select remote.default.to_base32(CAST('abc' AS VARBINARY))") + .getMaterializedRows().get(0).getField(0).toString(), + "MFRGG==="); + } + + @Test + public void testRemoteFunctionAppliedToColumn() + { + assertEquals(computeActual("SELECT remote.default.floor(o_totalprice) FROM tpch.sf1.orders") + .getMaterializedRows().size(), 1500000); + assertEquals(computeActual("SELECT remote.default.abs(l_discount) FROM tpch.sf1.lineitem") + .getMaterializedRows().size(), 6001215); + assertQueryWithSameQueryRunner( + "SELECT remote.default.abs(l_discount) FROM tpch.sf1.lineitem", + "SELECT abs(l_discount) FROM tpch.sf1.lineitem"); + assertEquals(computeActual("SELECT remote.default.length(CAST(o_comment AS VARBINARY)) FROM tpch.sf1.orders") + .getMaterializedRows().size(), 1500000); + } +}