diff --git a/.gitignore b/.gitignore
index a4512f9f794d..bf975bab3ff0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -66,3 +66,4 @@ presto-native-execution/deps-install
# Compiled executables used for docker build
/docker/presto-cli-*-executable.jar
/docker/presto-server-*.tar.gz
+/docker/presto-function-server-executable.jar
diff --git a/docker/Dockerfile b/docker/Dockerfile
index d44de58d549b..4ce01b8818d7 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -3,6 +3,7 @@ FROM quay.io/centos/centos:stream9
ARG PRESTO_VERSION
ARG PRESTO_PKG=presto-server-$PRESTO_VERSION.tar.gz
ARG PRESTO_CLI_JAR=presto-cli-$PRESTO_VERSION-executable.jar
+ARG PRESTO_REMOTE_SERVER_JAR=presto-function-server-executable.jar
ARG JMX_PROMETHEUS_JAVAAGENT_VERSION=0.20.0
ENV PRESTO_HOME="/opt/presto-server"
@@ -13,7 +14,8 @@ RUN --mount=type=cache,target=/var/cache/dnf,sharing=locked \
# clean cache jobs
&& mv /etc/yum/protected.d/systemd.conf /etc/yum/protected.d/systemd.conf.bak
-COPY --chmod=755 $PRESTO_CLI_JAR /opt/presto-cli
+COPY --chmod=755 $PRESTO_CLI_JAR /opt/presto-cli
+COPY --chmod=755 $PRESTO_REMOTE_SERVER_JAR /opt/presto-remote-function-server
RUN --mount=type=bind,source=$PRESTO_PKG,target=/$PRESTO_PKG \
# Download Presto and move \
diff --git a/presto-native-execution/.dockerignore b/presto-native-execution/.dockerignore
new file mode 100644
index 000000000000..b25c310f8162
--- /dev/null
+++ b/presto-native-execution/.dockerignore
@@ -0,0 +1,4 @@
+# Ignore build directories
+_build/
+cmake-build-debug/
+cmake-build-release/
diff --git a/presto-native-execution/CMakeLists.txt b/presto-native-execution/CMakeLists.txt
index d7becdc6188a..fc5d1ddf6269 100644
--- a/presto-native-execution/CMakeLists.txt
+++ b/presto-native-execution/CMakeLists.txt
@@ -136,7 +136,7 @@ set(Boost_USE_MULTITHREADED TRUE)
find_package(
Boost
1.66.0
- REQUIRED program_options context filesystem regex thread system date_time atomic
+ REQUIRED program_options context filesystem regex thread system date_time url atomic
)
include_directories(SYSTEM ${Boost_INCLUDE_DIRS})
diff --git a/presto-native-execution/pom.xml b/presto-native-execution/pom.xml
index 9eeb4b10250e..34c0fcb10701 100644
--- a/presto-native-execution/pom.xml
+++ b/presto-native-execution/pom.xml
@@ -503,6 +503,12 @@
presto-cli-*-executable.jar
+
+ ${project.parent.basedir}/presto-function-server/target
+
+ presto-function-server-executable.jar
+
+
${project.parent.basedir}/presto-server/target
@@ -574,7 +580,7 @@
Release
presto-native-dependency:latest
- -DPRESTO_ENABLE_TESTING=OFF
+ "-DPRESTO_ENABLE_TESTING=OFF -DPRESTO_ENABLE_REMOTE_FUNCTIONS=ON"
2
ubuntu:22.04
ubuntu
diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp
index 11811b5923ba..9fcf03384215 100644
--- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp
+++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp
@@ -1391,11 +1391,12 @@ void PrestoServer::registerRemoteFunctions() {
<< catalogName << "' catalog.";
} else {
VELOX_FAIL(
- "To register remote functions using a json file path you need to "
- "specify the remote server location using '{}', '{}' or '{}'.",
+ "To register remote functions you need to specify the remote server "
+ "location using '{}', '{}' or '{}' or {}.",
SystemConfig::kRemoteFunctionServerThriftAddress,
SystemConfig::kRemoteFunctionServerThriftPort,
- SystemConfig::kRemoteFunctionServerThriftUdsPath);
+ SystemConfig::kRemoteFunctionServerThriftUdsPath,
+ SystemConfig::kRemoteFunctionServerRestURL);
}
}
#endif
diff --git a/presto-native-execution/presto_cpp/main/common/Configs.cpp b/presto-native-execution/presto_cpp/main/common/Configs.cpp
index 80df7602f3f0..18cd399fdee3 100644
--- a/presto-native-execution/presto_cpp/main/common/Configs.cpp
+++ b/presto-native-execution/presto_cpp/main/common/Configs.cpp
@@ -479,6 +479,10 @@ std::string SystemConfig::remoteFunctionServerSerde() const {
return optionalProperty(kRemoteFunctionServerSerde).value();
}
+std::string SystemConfig::remoteFunctionServerRestURL() const {
+ return optionalProperty(kRemoteFunctionServerRestURL).value();
+}
+
int32_t SystemConfig::maxDriversPerTask() const {
return optionalProperty(kMaxDriversPerTask).value();
}
diff --git a/presto-native-execution/presto_cpp/main/common/Configs.h b/presto-native-execution/presto_cpp/main/common/Configs.h
index b3419a69e67e..e7f5d8d5d785 100644
--- a/presto-native-execution/presto_cpp/main/common/Configs.h
+++ b/presto-native-execution/presto_cpp/main/common/Configs.h
@@ -763,6 +763,10 @@ class SystemConfig : public ConfigBase {
static constexpr std::string_view kRemoteFunctionServerThriftUdsPath{
"remote-function-server.thrift.uds-path"};
+ /// HTTP URL used by the remote function rest server.
+ static constexpr std::string_view kRemoteFunctionServerRestURL{
+ "remote-function-server.rest.url"};
+
/// Path where json files containing signatures for remote functions can be
/// found.
static constexpr std::string_view
@@ -924,6 +928,8 @@ class SystemConfig : public ConfigBase {
std::string remoteFunctionServerSerde() const;
+ std::string remoteFunctionServerRestURL() const;
+
int32_t maxDriversPerTask() const;
int32_t driverMaxSplitPreload() const;
diff --git a/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt
index 88aa887dd400..20020ea182e5 100644
--- a/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt
+++ b/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt
@@ -11,10 +11,14 @@
# limitations under the License.
add_library(presto_function_metadata OBJECT FunctionMetadata.cpp)
-target_link_libraries(presto_function_metadata velox_function_registry)
+target_link_libraries(presto_function_metadata presto_common velox_function_registry)
add_subdirectory(dynamic_registry)
+if(PRESTO_ENABLE_REMOTE_FUNCTIONS)
+ add_subdirectory(remote)
+endif()
+
if(PRESTO_ENABLE_TESTING)
add_subdirectory(tests)
endif()
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/remote/CMakeLists.txt
new file mode 100644
index 000000000000..c2c4f8c7990f
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/CMakeLists.txt
@@ -0,0 +1,28 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_library(presto_to_velox_remote_functions PrestoRestFunctionRegistration.cpp)
+target_link_libraries(
+ presto_to_velox_remote_functions
+ presto_functions_remote
+ velox_type_fbhive
+ Boost::url
+)
+
+add_library(presto_functions_remote RestRemoteFunction.cpp)
+target_link_libraries(presto_functions_remote presto_functions_rest_client velox_functions_remote)
+
+add_subdirectory(client)
+
+if(${PRESTO_ENABLE_TESTING})
+ add_subdirectory(tests)
+endif()
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.cpp b/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.cpp
new file mode 100644
index 000000000000..bddde97e28e1
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.cpp
@@ -0,0 +1,173 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h"
+
+#include
+#include
+
+#include "presto_cpp/main/common/Configs.h"
+#include "presto_cpp/main/common/Utils.h"
+#include "presto_cpp/main/functions/remote/RestRemoteFunction.h"
+#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h"
+
+using facebook::velox::functions::remote::PageFormat;
+
+namespace facebook::presto::functions::remote::rest {
+namespace {
+// Returns the serialization/deserialization format used by the remote function
+// server. The format is determined by the system configuration value
+// "remoteFunctionServerSerde". Supported formats:
+// - "presto_page": Uses Presto page format.
+// - "spark_unsafe_row": Uses Spark unsafe row format.
+// @return PageFormat enum value corresponding to the configured serde format.
+// @throws VeloxException if the configured format is unknown.
+PageFormat getSerdeFormat() {
+ static const auto serdeFormat =
+ SystemConfig::instance()->remoteFunctionServerSerde();
+ if (serdeFormat == "presto_page") {
+ return PageFormat::PRESTO_PAGE;
+ } else if (serdeFormat == "spark_unsafe_row") {
+ return PageFormat::SPARK_UNSAFE_ROW;
+ } else {
+ VELOX_FAIL(
+ "Unknown serde name for remote function server: '{}'", serdeFormat);
+ }
+}
+
+// Encodes a string for safe inclusion in a URL by escaping non-alphanumeric
+// characters using percent-encoding. Alphanumeric characters and '-', '_', '.',
+// '~' are left unchanged. All other characters are replaced with '%' followed
+// by their two-digit hexadecimal value.
+// @param value The input string to encode.
+// @return The URL-encoded string.
+std::string urlEncode(const std::string& value) {
+ return boost::urls::encode(value, boost::urls::unreserved_chars);
+}
+
+std::string getFunctionName(const protocol::SqlFunctionId& functionId) {
+ // Example: "namespace.schema.function;TYPE;TYPE".
+ const auto nameEnd = functionId.find(';');
+ // Assuming the possibility of missing ';' if there are no function arguments.
+ return nameEnd != std::string::npos ? functionId.substr(0, nameEnd)
+ : functionId;
+}
+
+// Constructs a Velox function signature from a Presto function signature. This
+// function translates type variable constraints, integer variable constraints,
+// return type, argument types, and variable arity from the Presto signature to
+// the corresponding Velox signature builder.
+// @param prestoSignature The Presto function signature to convert.
+// @return A pointer to the constructed Velox function signature.
+velox::exec::FunctionSignaturePtr buildVeloxSignatureFromPrestoSignature(
+ const protocol::Signature& prestoSignature) {
+ velox::exec::FunctionSignatureBuilder signatureBuilder;
+
+ for (const auto& typeVar : prestoSignature.typeVariableConstraints) {
+ signatureBuilder.typeVariable(typeVar.name);
+ }
+
+ for (const auto& longVar : prestoSignature.longVariableConstraints) {
+ signatureBuilder.integerVariable(longVar.name);
+ }
+ signatureBuilder.returnType(prestoSignature.returnType);
+
+ for (const auto& argType : prestoSignature.argumentTypes) {
+ signatureBuilder.argumentType(argType);
+ }
+
+ if (prestoSignature.variableArity) {
+ signatureBuilder.variableArity();
+ }
+ return signatureBuilder.build();
+}
+
+} // namespace
+
+void registerRestRemoteFunction(
+ const protocol::RestFunctionHandle& restFunctionHandle) {
+ static std::mutex registrationMutex;
+ static std::unordered_map registeredFunctionHandles;
+ static std::unordered_map
+ restClient;
+ static const std::string remoteFunctionServerRestURL =
+ SystemConfig::instance()->remoteFunctionServerRestURL();
+
+ const std::string functionId = restFunctionHandle.functionId;
+
+ json functionHandleJson;
+ to_json(functionHandleJson, restFunctionHandle);
+ functionHandleJson["url"] = remoteFunctionServerRestURL;
+ const std::string serializedFunctionHandle = functionHandleJson.dump();
+
+ // Check if already registered (read-only, no lock needed for initial check)
+ {
+ std::lock_guard lock(registrationMutex);
+ auto it = registeredFunctionHandles.find(functionId);
+ if (it != registeredFunctionHandles.end() &&
+ it->second == serializedFunctionHandle) {
+ return;
+ }
+ }
+
+ // Get or create shared RestRemoteClient for this server URL
+ functions::rest::RestRemoteClientPtr remoteClient;
+ {
+ std::lock_guard lock(registrationMutex);
+ auto clientIt = restClient.find(remoteFunctionServerRestURL);
+ if (clientIt == restClient.end()) {
+ restClient[remoteFunctionServerRestURL] =
+ std::make_shared(
+ remoteFunctionServerRestURL);
+ }
+ remoteClient = restClient[remoteFunctionServerRestURL];
+ }
+
+ functions::rest::VeloxRemoteFunctionMetadata metadata;
+
+ // Extract function name parts using the utility function
+ const std::string functionName =
+ getFunctionName(restFunctionHandle.functionId);
+ const auto parts = util::getFunctionNameParts(functionName);
+ const std::string schema = parts[1];
+ const std::string function = parts[2];
+
+ const std::string functionLocation = fmt::format(
+ "{}/v1/functions/{}/{}/{}/{}",
+ remoteFunctionServerRestURL,
+ schema,
+ function,
+ urlEncode(restFunctionHandle.functionId),
+ restFunctionHandle.version);
+ metadata.location = functionLocation;
+ metadata.serdeFormat = getSerdeFormat();
+
+ auto veloxSignature =
+ buildVeloxSignatureFromPrestoSignature(restFunctionHandle.signature);
+ std::vector veloxSignatures = {
+ veloxSignature};
+
+ functions::rest::registerVeloxRemoteFunction(
+ getFunctionName(restFunctionHandle.functionId),
+ veloxSignatures,
+ metadata,
+ remoteClient);
+
+ // Update registration map
+ {
+ std::lock_guard lock(registrationMutex);
+ registeredFunctionHandles[functionId] = serializedFunctionHandle;
+ }
+}
+} // namespace facebook::presto::functions::remote::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h b/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h
new file mode 100644
index 000000000000..fcaf3d086f18
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h
@@ -0,0 +1,23 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "presto_cpp/presto_protocol/presto_protocol.h"
+
+namespace facebook::presto::functions::remote::rest {
+
+void registerRestRemoteFunction(
+ const protocol::RestFunctionHandle& restFunctionHandle);
+} // namespace facebook::presto::functions::remote::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.cpp b/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.cpp
new file mode 100644
index 000000000000..b8f1c3cade7e
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.cpp
@@ -0,0 +1,103 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "presto_cpp/main/functions/remote/RestRemoteFunction.h"
+#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h"
+#include "velox/functions/remote/client/RemoteVectorFunction.h"
+
+using namespace facebook::velox;
+namespace facebook::presto::functions::rest {
+namespace {
+
+class RestRemoteFunction : public velox::functions::RemoteVectorFunction {
+ public:
+ RestRemoteFunction(
+ const std::string& functionName,
+ const std::vector& inputArgs,
+ const VeloxRemoteFunctionMetadata& metadata,
+ RestRemoteClientPtr restClient)
+ : RemoteVectorFunction(functionName, inputArgs, metadata),
+ location_(metadata.location),
+ serdeFormat_(metadata.serdeFormat),
+ restClient_(std::move(restClient)) {}
+
+ protected:
+ std::unique_ptr
+ invokeRemoteFunction(
+ const velox::functions::remote::RemoteFunctionRequest& request)
+ const override {
+ VELOX_CHECK(restClient_, "Remote client not initialized.");
+
+ // Clone the request payload for the REST call
+ auto requestBody = request.inputs()->payload()->clone();
+
+ auto responseBody = restClient_->invokeFunction(
+ location_, serdeFormat_, std::move(requestBody));
+
+ if (!responseBody) {
+ VELOX_FAIL("No response received from remote function invocation.");
+ }
+
+ // Convert REST response to RemoteFunctionResponse
+ auto response =
+ std::make_unique();
+ velox::functions::remote::RemoteFunctionPage result;
+ result.payload_ref() = std::move(*responseBody);
+ response->result_ref() = std::move(result);
+ return response;
+ }
+
+ std::string remoteLocationToString() const override {
+ return location_;
+ }
+
+ private:
+ const std::string location_;
+ const velox::functions::remote::PageFormat serdeFormat_;
+ const RestRemoteClientPtr restClient_;
+};
+
+std::shared_ptr createRestRemoteFunction(
+ const std::string& name,
+ const std::vector& inputArgs,
+ const core::QueryConfig& /*config*/,
+ const VeloxRemoteFunctionMetadata& metadata,
+ RestRemoteClientPtr restClient) {
+ return std::make_shared(
+ name, inputArgs, metadata, restClient);
+}
+
+} // namespace
+
+void registerVeloxRemoteFunction(
+ const std::string& name,
+ const std::vector& signatures,
+ VeloxRemoteFunctionMetadata metadata,
+ RestRemoteClientPtr restClient,
+ bool overwrite) {
+ registerStatefulVectorFunction(
+ name,
+ signatures,
+ std::bind(
+ createRestRemoteFunction,
+ std::placeholders::_1,
+ std::placeholders::_2,
+ std::placeholders::_3,
+ metadata,
+ restClient),
+ std::move(metadata),
+ overwrite);
+}
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.h b/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.h
new file mode 100644
index 000000000000..408fc6700cc6
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/RestRemoteFunction.h
@@ -0,0 +1,35 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h"
+#include "velox/functions/remote/client/RemoteVectorFunction.h"
+
+namespace facebook::presto::functions::rest {
+
+struct VeloxRemoteFunctionMetadata
+ : public velox::functions::RemoteVectorFunctionMetadata {
+ /// URL of the HTTP/REST server for remote function.
+ std::string location;
+};
+
+void registerVeloxRemoteFunction(
+ const std::string& name,
+ const std::vector& signatures,
+ VeloxRemoteFunctionMetadata metadata,
+ RestRemoteClientPtr restClient,
+ bool overwrite = true);
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/client/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/remote/client/CMakeLists.txt
new file mode 100644
index 000000000000..0b714aff8d91
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/client/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_library(presto_functions_rest_client RestRemoteClient.cpp)
+target_link_libraries(presto_functions_rest_client presto_http)
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.cpp b/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.cpp
new file mode 100644
index 000000000000..0d8b94710fe8
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.cpp
@@ -0,0 +1,113 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h"
+
+#include
+#include
+
+#include "presto_cpp/main/functions/remote/utils/ContentTypes.h"
+#include "velox/common/base/Exceptions.h"
+#include "velox/common/memory/Memory.h"
+
+using namespace facebook::velox;
+
+namespace facebook::presto::functions::rest {
+namespace {
+inline std::string getContentType(velox::functions::remote::PageFormat fmt) {
+ return fmt == velox::functions::remote::PageFormat::SPARK_UNSAFE_ROW
+ ? remote::CONTENT_TYPE_SPARK_UNSAFE_ROW
+ : remote::CONTENT_TYPE_PRESTO_PAGE;
+}
+} // namespace
+
+RestRemoteClient::RestRemoteClient(const std::string& url) : url_(url) {
+ memPool_ = memory::MemoryManager::getInstance()->addLeafPool();
+ folly::Uri uri(url_);
+ proxygen::Endpoint endpoint(uri.host(), uri.port(), uri.scheme() == "https");
+ folly::SocketAddress addr(uri.host().c_str(), uri.port(), true);
+
+ evbThread_ = std::make_unique("rest-client");
+ httpClient_ = std::make_shared(
+ evbThread_->getEventBase(),
+ nullptr,
+ endpoint,
+ addr,
+ requestTimeoutMs,
+ connectTimeoutMs,
+ memPool_,
+ nullptr);
+}
+
+RestRemoteClient::~RestRemoteClient() {
+ if (httpClient_) {
+ evbThread_->getEventBase()->runInEventBaseThreadAndWait(
+ [client = std::move(httpClient_)]() mutable { client.reset(); });
+ }
+ evbThread_.reset();
+}
+
+std::unique_ptr RestRemoteClient::invokeFunction(
+ const std::string& fullUrl,
+ velox::functions::remote::PageFormat serdeFormat,
+ std::unique_ptr requestPayload) const {
+ try {
+ folly::Uri uri(fullUrl);
+ const std::string contentType = getContentType(serdeFormat);
+ auto message = std::make_unique();
+ message->setMethod(proxygen::HTTPMethod::POST);
+ message->setURL(uri.path());
+ message->setHTTPVersion(1, 1);
+ message->getHeaders().add("Content-Type", contentType);
+ message->getHeaders().add("Accept", contentType);
+
+ requestPayload->coalesce();
+ std::string requestBody = requestPayload->moveToFbString().toStdString();
+
+ auto sendFuture = httpClient_->sendRequest(*message, requestBody);
+ sendFuture.wait();
+
+ VELOX_CHECK(
+ sendFuture.hasValue(),
+ "Invalid response returned from HTTP request to {}.",
+ uri.host());
+
+ std::unique_ptr resp = std::move(sendFuture).get();
+
+ if (!resp) {
+ VELOX_FAIL(
+ "Null response object returned from HTTP request to {}.", uri.host());
+ }
+
+ if (resp->hasError()) {
+ VELOX_FAIL("HTTP error: {}", resp->error());
+ }
+
+ int status = resp->headers()->getStatusCode();
+ if (status < http::kHttpOk || status >= http::kHttpMultipleChoices) {
+ VELOX_FAIL(
+ "Server responded with status {}. Body: '{}'. URL: {}",
+ status,
+ resp->dumpBodyChain(),
+ fullUrl);
+ }
+
+ return folly::IOBuf::copyBuffer(resp->dumpBodyChain());
+ } catch (const std::exception& ex) {
+ VELOX_FAIL("HTTP invocation failed for URL {}: {}", fullUrl, ex.what());
+ }
+ return nullptr;
+}
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.h b/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.h
new file mode 100644
index 000000000000..9fe89944c720
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/client/RestRemoteClient.h
@@ -0,0 +1,53 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+
+#include "presto_cpp/main/common/Configs.h"
+#include "presto_cpp/main/http/HttpClient.h"
+#include "velox/functions/remote/if/gen-cpp2/RemoteFunction_types.h"
+
+namespace facebook::presto::functions::rest {
+
+class RestRemoteClient {
+ public:
+ RestRemoteClient(const std::string& url);
+
+ ~RestRemoteClient();
+
+ std::unique_ptr invokeFunction(
+ const std::string& fullUrl,
+ velox::functions::remote::PageFormat serdeFormat,
+ std::unique_ptr requestPayload) const;
+
+ private:
+ const std::string url_;
+ std::unique_ptr evbThread_;
+ std::shared_ptr httpClient_;
+ std::shared_ptr memPool_;
+
+ const std::chrono::milliseconds requestTimeoutMs =
+ std::chrono::duration_cast(
+ SystemConfig::instance()->exchangeRequestTimeoutMs());
+
+ const std::chrono::milliseconds connectTimeoutMs =
+ std::chrono::duration_cast(
+ SystemConfig::instance()->exchangeConnectTimeoutMs());
+};
+
+using RestRemoteClientPtr = std::shared_ptr;
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/remote/tests/CMakeLists.txt
new file mode 100644
index 000000000000..59baa1eb24fe
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/CMakeLists.txt
@@ -0,0 +1,30 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(${PRESTO_ENABLE_TESTING})
+ add_subdirectory(server)
+endif()
+
+add_executable(presto_functions_rest_test RemoteFunctionRestTest.cpp)
+
+add_test(presto_functions_rest_test presto_functions_rest_test)
+
+target_link_libraries(
+ presto_functions_rest_test
+ presto_functions_rest_client
+ presto_functions_rest_server
+ presto_functions_remote
+ velox_functions_test_lib
+ GTest::gmock
+ GTest::gtest
+ GTest::gtest_main
+)
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/RemoteFunctionRestTest.cpp b/presto-native-execution/presto_cpp/main/functions/remote/tests/RemoteFunctionRestTest.cpp
new file mode 100644
index 000000000000..ddf233b7647d
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/RemoteFunctionRestTest.cpp
@@ -0,0 +1,299 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include "presto_cpp/main/functions/remote/RestRemoteFunction.h"
+#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h"
+#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h"
+#include "presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h"
+#include "presto_cpp/main/functions/remote/tests/server/examples/RemoteDoubleDivHandler.h"
+#include "presto_cpp/main/functions/remote/tests/server/examples/RemoteFibonacciHandler.h"
+#include "presto_cpp/main/functions/remote/tests/server/examples/RemoteInverseCdfHandler.h"
+#include "presto_cpp/main/functions/remote/tests/server/examples/RemoteRemoveCharHandler.h"
+#include "presto_cpp/main/functions/remote/tests/server/examples/RemoteStrLenHandler.h"
+#include "velox/common/base/tests/GTestUtils.h"
+#include "velox/exec/tests/utils/PortUtil.h"
+#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
+#include "velox/type/fbhive/HiveTypeParser.h"
+
+using ::facebook::velox::test::assertEqualVectors;
+using namespace facebook::velox;
+namespace facebook::presto::functions::rest::test {
+namespace {
+
+class RemoteFunctionRestTest
+ : public velox::functions::test::FunctionBaseTest,
+ public testing::WithParamInterface {
+ public:
+ void SetUp() override {
+ auto servicePort = exec::test::getFreePort();
+ location_ = fmt::format(kHostAddressTemplate_, servicePort);
+ wrongLocation_ =
+ fmt::format(kHostAddressTemplate_, exec::test::getFreePort());
+
+ restClient_ = std::make_shared(location_);
+ wrongRestClient_ = std::make_shared(wrongLocation_);
+
+ initializeServer(servicePort);
+ registerRemoteFunctions();
+ }
+
+ ~RemoteFunctionRestTest() override {
+ if (serverThread_ && serverThread_->joinable()) {
+ ioc_.stop();
+ serverThread_->join();
+ }
+ }
+
+ private:
+ // Registers a remote function by creating a handler instance and registering
+ // it on both the server and client side. The handler provides its own input
+ // and output types, eliminating the need for manual type specification.
+ template
+ void registerRemoteFunctionHelper(
+ const std::string& functionName,
+ const std::string& baseLocation,
+ RestRemoteClientPtr client) const {
+ auto handler = std::make_shared();
+
+ auto inputTypes = handler->getInputTypes();
+ auto outputType = handler->getOutputType();
+
+ auto signatureBuilder = exec::FunctionSignatureBuilder();
+ signatureBuilder.returnType(outputType->toString());
+ for (const auto& childType : inputTypes->children()) {
+ signatureBuilder.argumentType(childType->toString());
+ }
+
+ RestFunctionRegistry::getInstance().registerFunction(functionName, handler);
+
+ VeloxRemoteFunctionMetadata metadata;
+ metadata.serdeFormat = GetParam();
+ metadata.location = baseLocation + "/" + functionName;
+ registerVeloxRemoteFunction(
+ functionName, {signatureBuilder.build()}, metadata, client);
+ }
+
+ void registerRemoteFunctions() const {
+ registerRemoteFunctionHelper(
+ "remote_fibonacci", location_, restClient_);
+ registerRemoteFunctionHelper(
+ "remote_strlen", location_, restClient_);
+ registerRemoteFunctionHelper(
+ "remote_remove_char", location_, restClient_);
+ registerRemoteFunctionHelper(
+ "remote_inverse_cdf", location_, restClient_);
+ registerRemoteFunctionHelper(
+ "remote_divide", location_, restClient_);
+ registerRemoteFunctionHelper(
+ "remote_wrong_port", wrongLocation_, wrongRestClient_);
+
+ // Register a fake function handler whose logic is intentionally not
+ // implemented in the server. This simulates a failure scenario for testing
+ // purposes.
+ auto roundSignatures = {exec::FunctionSignatureBuilder()
+ .returnType("integer")
+ .argumentType("integer")
+ .build()};
+ VeloxRemoteFunctionMetadata metadata;
+ metadata.serdeFormat = GetParam();
+ metadata.location = location_ + "/remote_round";
+ registerVeloxRemoteFunction(
+ "remote_round", roundSignatures, metadata, restClient_);
+ }
+
+ void initializeServer(uint16_t servicePort) {
+ serverThread_ = std::make_unique([this, servicePort]() {
+ std::make_shared(
+ ioc_,
+ boost::asio::ip::tcp::endpoint(
+ boost::asio::ip::make_address(kServiceHost), servicePort))
+ ->run();
+ ioc_.run();
+ });
+
+ VELOX_CHECK(
+ waitForRunning(servicePort), "Unable to initialize HTTP server.");
+ }
+
+ bool waitForRunning(uint16_t servicePort) const {
+ for (size_t i = 0; i < 100; ++i) {
+ using boost::asio::ip::tcp;
+ boost::asio::io_context io_context;
+
+ tcp::socket socket(io_context);
+ tcp::resolver resolver(io_context);
+
+ try {
+ boost::asio::connect(
+ socket,
+ resolver.resolve(kServiceHost, std::to_string(servicePort)));
+ return true;
+ } catch (std::exception& e) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(500));
+ }
+ }
+ return false;
+ }
+
+ std::unique_ptr serverThread_;
+ boost::asio::io_context ioc_{1};
+ static constexpr auto kHostAddressTemplate_ = "http://127.0.0.1:{}";
+ static constexpr auto kServiceHost = "127.0.0.1";
+
+ std::string location_;
+ std::string wrongLocation_;
+ RestRemoteClientPtr restClient_;
+ RestRemoteClientPtr wrongRestClient_;
+};
+
+TEST_P(RemoteFunctionRestTest, connectionError) {
+ auto numeratorVector = makeFlatVector({0, 1, 4, 9, 16, 25, -25});
+ auto denominatorVector = makeFlatVector({0, 1, 2, 3, 4, 0, 2});
+ auto data = makeRowVector({numeratorVector, denominatorVector});
+ VELOX_ASSERT_THROW(
+ evaluate>("remote_wrong_port(c0,c1)", data),
+ "HTTP invocation failed for URL");
+}
+
+TEST_P(RemoteFunctionRestTest, fibonacci) {
+ auto inputVector = makeFlatVector({10, 20});
+ auto results = evaluate>(
+ "remote_fibonacci(c0)", makeRowVector({inputVector}));
+
+ auto expected = makeFlatVector({55, 6765});
+ assertEqualVectors(expected, results);
+}
+
+TEST_P(RemoteFunctionRestTest, stringLength) {
+ auto inputVector =
+ makeFlatVector({"hello", "from", "remote", "server"});
+ auto results = evaluate>(
+ "remote_strlen(c0)", makeRowVector({inputVector}));
+
+ auto expected = makeFlatVector({5, 4, 6, 6});
+ assertEqualVectors(expected, results);
+}
+
+TEST_P(RemoteFunctionRestTest, removeChar) {
+ auto input = makeFlatVector(
+ {"hello from remote server",
+ "testing remote server",
+ "My file, named 'data_report#2.csv', is located in the folder: C:\\Users\\User\\Documents! It's quite large (~1.5GB)."});
+ auto charToRemove = makeFlatVector({"o", "e", "c"});
+ auto results = evaluate>(
+ "remote_remove_char(c0,c1)", makeRowVector({input, charToRemove}));
+
+ auto expected = makeFlatVector(
+ {"hell frm remte server",
+ "tsting rmot srvr",
+ "My file, named 'data_report#2.sv', is loated in the folder: C:\\Users\\User\\Douments! It's quite large (~1.5GB)."});
+ assertEqualVectors(expected, results);
+}
+
+TEST_P(RemoteFunctionRestTest, tryException) {
+ // remote_divide throws if denominator is 0.
+ auto numeratorVector = makeFlatVector({0, 1, 4, 9, 16, 25, -25});
+ auto denominatorVector = makeFlatVector({0, 1, 2, 3, 4, 0, 2});
+ auto data = makeRowVector({numeratorVector, denominatorVector});
+ auto results = evaluate>("remote_divide(c0, c1)", data);
+
+ ASSERT_EQ(results->size(), 7);
+ auto expected = makeFlatVector({0, 1, 2, 3, 4, 0, -12.5});
+ expected->setNull(0, true);
+ expected->setNull(5, true);
+
+ assertEqualVectors(expected, results);
+}
+
+TEST_P(RemoteFunctionRestTest, inverseCdf) {
+ auto pVector = makeFlatVector({0.95, 0.95, 0.50, 0.10});
+ auto nuVector = makeFlatVector({4, 1, 10, 2});
+ auto data = makeRowVector({pVector, nuVector});
+ auto results =
+ evaluate>("remote_inverse_cdf(c0, c1)", data);
+
+ ASSERT_EQ(results->size(), 4);
+ auto expected = makeFlatVector({9.49, 3.84, 9.34, 0.21});
+
+ assertEqualVectors(expected, results);
+}
+
+TEST_P(RemoteFunctionRestTest, inverseCdfException) {
+ // p < 0, p > 1, nu <= 0 - these should throw exceptions
+ auto pVector = makeFlatVector({-0.1});
+ auto nuVector = makeFlatVector({4});
+ auto data = makeRowVector({pVector, nuVector});
+
+ VELOX_ASSERT_THROW(
+ evaluate>("remote_inverse_cdf(c0, c1)", data),
+ "inverse_chi_squared_cdf: p must be in (0,1)");
+
+ // Test p > 1
+ pVector = makeFlatVector({1.1});
+ nuVector = makeFlatVector({1});
+ data = makeRowVector({pVector, nuVector});
+
+ VELOX_ASSERT_THROW(
+ evaluate>("remote_inverse_cdf(c0, c1)", data),
+ "inverse_chi_squared_cdf: p must be in (0,1)");
+
+ // Test nu <= 0
+ pVector = makeFlatVector({0.5});
+ nuVector = makeFlatVector({0});
+ data = makeRowVector({pVector, nuVector});
+
+ VELOX_ASSERT_THROW(
+ evaluate>("remote_inverse_cdf(c0, c1)", data),
+ "inverse_chi_squared_cdf: degrees of freedom must be > 0");
+
+ // Test nu < 0
+ pVector = makeFlatVector({0.95});
+ nuVector = makeFlatVector({-2});
+ data = makeRowVector({pVector, nuVector});
+
+ VELOX_ASSERT_THROW(
+ evaluate>("remote_inverse_cdf(c0, c1)", data),
+ "inverse_chi_squared_cdf: degrees of freedom must be > 0");
+}
+
+TEST_P(RemoteFunctionRestTest, functionNotAvailable) {
+ auto inputVector = makeFlatVector({-10, -20});
+ VELOX_ASSERT_THROW(
+ evaluate>(
+ "remote_round(c0)", makeRowVector({inputVector})),
+ "Server responded with status 400. Body: 'Function 'remote_round' is not available.'");
+}
+
+VELOX_INSTANTIATE_TEST_SUITE_P(
+ RemoteFunctionRestTestFixture,
+ RemoteFunctionRestTest,
+ ::testing::Values(
+ velox::functions::remote::PageFormat::PRESTO_PAGE,
+ velox::functions::remote::PageFormat::SPARK_UNSAFE_ROW));
+
+} // namespace
+} // namespace facebook::presto::functions::rest::test
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ folly::Init init{&argc, &argv, false};
+ return RUN_ALL_TESTS();
+}
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/CMakeLists.txt
new file mode 100644
index 000000000000..9cd7cd3bcc41
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_library(presto_functions_rest_server RemoteFunctionRestService.cpp RestFunctionRegistry.cpp)
+target_link_libraries(presto_functions_rest_server velox_presto_serializer)
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h
new file mode 100644
index 000000000000..b33dac334f86
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h
@@ -0,0 +1,77 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include
+
+#include "velox/vector/VectorStream.h"
+
+namespace facebook::presto::functions::rest {
+
+class RemoteFunctionRestHandler {
+ public:
+ RemoteFunctionRestHandler() = default;
+
+ virtual ~RemoteFunctionRestHandler() = default;
+
+ virtual velox::RowTypePtr getInputTypes() const = 0;
+ virtual velox::TypePtr getOutputType() const = 0;
+
+ folly::IOBuf handleRequest(
+ std::unique_ptr inputBuffer,
+ velox::VectorSerde* serde,
+ velox::memory::MemoryPool* pool,
+ std::string& errorMessage) {
+ auto inputTypes = getInputTypes();
+ auto outputType = getOutputType();
+
+ auto inputVector = IOBufToRowVector(*inputBuffer, inputTypes, *pool, serde);
+
+ VELOX_CHECK_EQ(
+ inputVector->childrenSize(),
+ inputTypes->children().size(),
+ "Mismatched number of columns for remote function handler.");
+
+ const auto numRows = inputVector->size();
+ auto resultVector = velox::BaseVector::create(outputType, numRows, pool);
+
+ compute(inputVector, resultVector, errorMessage);
+
+ if (!errorMessage.empty()) {
+ return folly::IOBuf();
+ }
+
+ // Wrap the result in a RowVector to send back.
+ auto outputRowVector = std::make_shared(
+ pool,
+ velox::ROW({outputType}),
+ velox::BufferPtr(),
+ numRows,
+ std::vector{resultVector});
+
+ auto payload = rowVectorToIOBuf(
+ outputRowVector, outputRowVector->size(), *pool, serde);
+
+ return payload;
+ }
+
+ protected:
+ // Core computation function to be implemented by subclasses.
+ virtual void compute(
+ const velox::RowVectorPtr& inputVector,
+ const velox::VectorPtr& resultVector,
+ std::string& errorMessage) = 0;
+};
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.cpp b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.cpp
new file mode 100644
index 000000000000..2c944687f1f3
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.cpp
@@ -0,0 +1,314 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h"
+
+#include
+
+#include "presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h"
+#include "presto_cpp/main/functions/remote/utils/ContentTypes.h"
+#include "velox/serializers/PrestoSerializer.h"
+#include "velox/serializers/UnsafeRowSerializer.h"
+
+using namespace facebook::velox;
+namespace facebook::presto::functions::rest {
+
+RestSession::RestSession(boost::asio::ip::tcp::socket socket)
+ : socket_(std::move(socket)),
+ pool_(memory::memoryManager()->addLeafPool()) {}
+
+void RestSession::run() {
+ doRead();
+}
+
+void RestSession::doRead() {
+ auto self = shared_from_this();
+ boost::beast::http::async_read(
+ socket_,
+ buffer_,
+ req_,
+ [self](boost::beast::error_code ec, size_t bytes_transferred) {
+ self->onRead(ec, bytes_transferred);
+ });
+}
+
+void RestSession::onRead(
+ boost::beast::error_code ec,
+ std::size_t bytes_transferred) {
+ boost::ignore_unused(bytes_transferred);
+
+ if (ec == boost::beast::http::error::end_of_stream) {
+ return doClose();
+ }
+
+ if (ec) {
+ LOG(ERROR) << "Read error: " << ec.message();
+ return;
+ }
+
+ handleRequest(std::move(req_));
+}
+
+void RestSession::handleRequest(
+ const boost::beast::http::request& req) {
+ res_.version(req.version());
+ res_.set(boost::beast::http::field::server, BOOST_BEAST_VERSION_STRING);
+
+ if (!ensurePostMethod(req)) {
+ return;
+ }
+
+ if (!ensureValidContentType(req)) {
+ return;
+ }
+
+ if (!ensureValidAcceptHeader(req)) {
+ return;
+ }
+
+ std::string functionName;
+ if (!extractFunctionName(req, functionName)) {
+ return;
+ }
+
+ try {
+ auto& registry = RestFunctionRegistry::getInstance();
+ auto handler = registry.getFunction(functionName);
+ if (!handler) {
+ sendResponse(
+ boost::beast::http::status::bad_request,
+ plainText_,
+ fmt::format("Function '{}' is not available.", functionName),
+ false);
+ return;
+ }
+
+ std::unique_ptr serde;
+ if (contentType_ == remote::CONTENT_TYPE_PRESTO_PAGE) {
+ serde = std::make_unique();
+ } else {
+ serde = std::make_unique();
+ }
+ auto inputBuffer = folly::IOBuf::copyBuffer(req.body());
+
+ std::string errorMessage;
+ auto outputBuffer = handler->handleRequest(
+ std::move(inputBuffer), serde.get(), pool_.get(), errorMessage);
+ if (!errorMessage.empty()) {
+ sendResponse(
+ boost::beast::http::status::internal_server_error,
+ plainText_,
+ errorMessage,
+ true);
+ return;
+ }
+ sendSuccessResponse(std::move(outputBuffer));
+ } catch (const std::exception& ex) {
+ handleException(ex);
+ }
+}
+
+void RestSession::sendResponse(
+ boost::beast::http::status status,
+ const std::string& contentType,
+ const std::string& body,
+ bool closeConnection) {
+ res_.result(status);
+ res_.set(boost::beast::http::field::content_type, contentType);
+ res_.body() = body;
+ res_.prepare_payload();
+
+ auto self = shared_from_this();
+ boost::beast::http::async_write(
+ socket_,
+ res_,
+ [self, closeConnection](
+ boost::beast::error_code ec, std::size_t bytes_transferred) {
+ self->onWrite(closeConnection, ec, bytes_transferred);
+ });
+}
+
+bool RestSession::ensurePostMethod(
+ const boost::beast::http::request& req) {
+ if (req.method() == boost::beast::http::verb::post) {
+ return true;
+ }
+ auto msg = fmt::format(
+ "Only POST method is allowed. Method used: {}",
+ std::string(req.method_string()));
+ sendResponse(
+ boost::beast::http::status::method_not_allowed, plainText_, msg, true);
+ return false;
+}
+
+bool RestSession::ensureValidContentType(
+ const boost::beast::http::request& req) {
+ auto contentType = req[boost::beast::http::field::content_type];
+ if (!contentType.empty() &&
+ (contentType == remote::CONTENT_TYPE_SPARK_UNSAFE_ROW ||
+ contentType == remote::CONTENT_TYPE_PRESTO_PAGE)) {
+ contentType_ = contentType;
+ return true;
+ }
+
+ auto msg = fmt::format(
+ "Unsupported Content-Type: '{}'. Expecting '{}' or '{}'.",
+ std::string(contentType),
+ remote::CONTENT_TYPE_PRESTO_PAGE,
+ remote::CONTENT_TYPE_SPARK_UNSAFE_ROW);
+ sendResponse(
+ boost::beast::http::status::unsupported_media_type,
+ plainText_,
+ msg,
+ true);
+ return false;
+}
+
+bool RestSession::ensureValidAcceptHeader(
+ const boost::beast::http::request& req) {
+ auto acceptHeader = req[boost::beast::http::field::accept];
+ if (!acceptHeader.empty() &&
+ (acceptHeader == remote::CONTENT_TYPE_PRESTO_PAGE ||
+ acceptHeader == remote::CONTENT_TYPE_SPARK_UNSAFE_ROW)) {
+ return true;
+ }
+
+ auto msg = fmt::format(
+ "Unsupported Accept header: '{}'. Expecting '{}' or '{}'.",
+ std::string(acceptHeader),
+ remote::CONTENT_TYPE_PRESTO_PAGE,
+ remote::CONTENT_TYPE_SPARK_UNSAFE_ROW);
+ sendResponse(
+ boost::beast::http::status::not_acceptable, plainText_, msg, true);
+ return false;
+}
+
+bool RestSession::extractFunctionName(
+ const boost::beast::http::request&
+ requestBody,
+ std::string& functionName) {
+ std::string path = std::string(requestBody.target());
+ std::vector pathComponents;
+ folly::split('/', path, pathComponents);
+
+ if (pathComponents.size() != 2) {
+ sendResponse(
+ boost::beast::http::status::bad_request,
+ plainText_,
+ "Invalid request path",
+ true);
+ return false;
+ }
+
+ // pathComponents[1] contains the function name
+ functionName = pathComponents[1];
+ return true;
+}
+
+void RestSession::handleException(const std::exception& ex) {
+ sendResponse(
+ boost::beast::http::status::internal_server_error,
+ plainText_,
+ ex.what(),
+ true);
+}
+
+void RestSession::sendSuccessResponse(folly::IOBuf&& payload) {
+ sendResponse(
+ boost::beast::http::status::ok,
+ contentType_,
+ payload.moveToFbString().toStdString(),
+ false);
+}
+
+void RestSession::onWrite(
+ bool close,
+ boost::beast::error_code ec,
+ std::size_t bytes_transferred) {
+ boost::ignore_unused(bytes_transferred);
+
+ if (ec) {
+ LOG(ERROR) << "Write error: " << ec.message();
+ return;
+ }
+
+ if (close) {
+ return doClose();
+ }
+
+ req_ = {};
+
+ doRead();
+}
+
+void RestSession::doClose() {
+ boost::beast::error_code ec;
+ socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_send, ec);
+}
+
+RestListener::RestListener(
+ boost::asio::io_context& ioc,
+ boost::asio::ip::tcp::endpoint endpoint)
+ : ioc_(ioc), acceptor_(ioc) {
+ boost::beast::error_code ec;
+
+ acceptor_.open(endpoint.protocol(), ec);
+ if (ec) {
+ LOG(ERROR) << "Open error: " << ec.message();
+ return;
+ }
+
+ acceptor_.set_option(boost::asio::socket_base::reuse_address(true), ec);
+ if (ec) {
+ LOG(ERROR) << "Set_option error: " << ec.message();
+ return;
+ }
+
+ acceptor_.bind(endpoint, ec);
+ if (ec) {
+ LOG(ERROR) << "Bind error: " << ec.message();
+ return;
+ }
+
+ acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec);
+ if (ec) {
+ LOG(ERROR) << "Listen error: " << ec.message();
+ return;
+ }
+}
+
+void RestListener::run() {
+ doAccept();
+}
+
+void RestListener::doAccept() {
+ acceptor_.async_accept(
+ [self = shared_from_this()](
+ boost::beast::error_code ec, boost::asio::ip::tcp::socket socket) {
+ self->onAccept(ec, std::move(socket));
+ });
+}
+
+void RestListener::onAccept(
+ boost::beast::error_code ec,
+ boost::asio::ip::tcp::socket socket) {
+ if (ec) {
+ LOG(ERROR) << "Accept error: " << ec.message();
+ } else {
+ std::make_shared(std::move(socket))->run();
+ }
+ doAccept();
+}
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h
new file mode 100644
index 000000000000..e24d889dd64f
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestService.h
@@ -0,0 +1,125 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h"
+#include "presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h"
+
+namespace facebook::presto::functions::rest {
+
+/// @brief Manages an individual HTTP session.
+/// Handles reading HTTP requests, processing them, and sending responses.
+/// Inspired by the reference implementation described in:
+/// https://medium.com/@AlexanderObregon/building-restful-apis-with-c-4c8ac63fe8a7
+class RestSession : public std::enable_shared_from_this {
+ public:
+ RestSession(boost::asio::ip::tcp::socket socket);
+
+ /// Starts the session by initiating a read operation.
+ void run();
+
+ private:
+ // Initiates an asynchronous read operation.
+ void doRead();
+
+ // Called when a read operation completes.
+ void onRead(boost::beast::error_code ec, std::size_t bytes_transferred);
+
+ // Processes the HTTP request and prepares a response.
+ void handleRequest(
+ const boost::beast::http::request& req);
+
+ // Called when a write operation completes.
+ void onWrite(
+ bool close,
+ boost::beast::error_code ec,
+ std::size_t bytes_transferred);
+
+ // Closes the socket connection.
+ void doClose();
+
+ // Helper to handle exceptions: logs error and sends HTTP 500
+ void handleException(const std::exception& ex);
+
+ // Helper to extract function name from /{functionName}
+ bool extractFunctionName(
+ const boost::beast::http::request&
+ requestBody,
+ std::string& functionName);
+
+ // Sends a response with status, contentType, and body. Then calls
+ // async_write.
+ void sendResponse(
+ boost::beast::http::status status,
+ const std::string& contentType,
+ const std::string& body,
+ bool closeConnection);
+
+ // Helper to ensure the HTTP method is POST, else sends an error response.
+ bool ensurePostMethod(
+ const boost::beast::http::request& req);
+
+ // Helper to ensure Content-Type is either "application/X-presto-pages" or
+ // "application/X-spark-unsafe-row". Otherwise sends an error response.
+ bool ensureValidContentType(
+ const boost::beast::http::request& req);
+
+ // Helper to ensure Accept is either "application/X-presto-pages" or
+ // "application/X-spark-unsafe-row". Otherwise sends an error response.
+ bool ensureValidAcceptHeader(
+ const boost::beast::http::request& req);
+
+ // Sends a success response with the given payload.
+ void sendSuccessResponse(folly::IOBuf&& payload);
+
+ boost::asio::ip::tcp::socket socket_;
+ boost::beast::flat_buffer buffer_;
+ boost::beast::http::request req_;
+ boost::beast::http::response res_;
+ std::shared_ptr pool_;
+ std::string contentType_;
+ static constexpr std::string plainText_ = "text/plain";
+};
+
+/// @brief Listens for incoming TCP connections and creates sessions.
+/// Sets up a TCP acceptor to listen for client connections,
+/// creating a new session for each accepted connection.
+class RestListener : public std::enable_shared_from_this {
+ public:
+ RestListener(
+ boost::asio::io_context& ioc,
+ boost::asio::ip::tcp::endpoint endpoint);
+
+ /// Starts accepting incoming connections.
+ void run();
+
+ private:
+ // Initiates an asynchronous accept operation.
+ void doAccept();
+
+ // Called when an accept operation completes.
+ void onAccept(
+ boost::beast::error_code ec,
+ boost::asio::ip::tcp::socket socket);
+
+ boost::asio::io_context& ioc_;
+ boost::asio::ip::tcp::acceptor acceptor_;
+};
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.cpp b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.cpp
new file mode 100644
index 000000000000..4cd56c209cd8
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.cpp
@@ -0,0 +1,69 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h"
+
+#include
+
+namespace facebook::presto::functions::rest {
+
+RestFunctionRegistry& RestFunctionRegistry::getInstance() {
+ static RestFunctionRegistry instance;
+ return instance;
+}
+
+bool RestFunctionRegistry::registerFunction(
+ const std::string& functionName,
+ std::shared_ptr handler) {
+ std::lock_guard lock(mutex_);
+
+ auto it = functionHandlers_.find(functionName);
+ bool replacing = (it != functionHandlers_.end());
+
+ if (replacing) {
+ LOG(WARNING) << "Function handler for '" << functionName
+ << "' is being replaced.";
+ }
+
+ functionHandlers_[functionName] = std::move(handler);
+ return replacing;
+}
+
+bool RestFunctionRegistry::unregisterFunction(const std::string& functionName) {
+ std::lock_guard lock(mutex_);
+ return functionHandlers_.erase(functionName) > 0;
+}
+
+std::shared_ptr RestFunctionRegistry::getFunction(
+ const std::string& functionName) const {
+ std::lock_guard lock(mutex_);
+
+ auto it = functionHandlers_.find(functionName);
+ if (it != functionHandlers_.end()) {
+ return it->second;
+ }
+ return nullptr;
+}
+
+bool RestFunctionRegistry::hasFunction(const std::string& functionName) const {
+ std::lock_guard lock(mutex_);
+ return functionHandlers_.find(functionName) != functionHandlers_.end();
+}
+
+void RestFunctionRegistry::clear() {
+ std::lock_guard lock(mutex_);
+ functionHandlers_.clear();
+}
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h
new file mode 100644
index 000000000000..4284e3c08670
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/RestFunctionRegistry.h
@@ -0,0 +1,77 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h"
+
+namespace facebook::presto::functions::rest {
+
+/// @brief Registry for remote function REST handlers.
+/// Provides a centralized location for registering and looking up function
+/// handlers by name. This follows the same pattern as WindowFunctionRegistry
+/// in Velox.
+class RestFunctionRegistry {
+ public:
+ /// Returns the singleton instance of the registry.
+ static RestFunctionRegistry& getInstance();
+
+ /// Registers a function handler for a given function name.
+ /// If a handler with the same name already exists, it will be replaced.
+ /// @param functionName The name of the function to register
+ /// @param handler The handler implementation for the function
+ /// @return true if a handler was replaced, false if this is a new
+ /// registration
+ bool registerFunction(
+ const std::string& functionName,
+ std::shared_ptr handler);
+
+ /// Unregisters a function handler by name.
+ /// @param functionName The name of the function to unregister
+ /// @return true if a handler was found and removed, false otherwise
+ bool unregisterFunction(const std::string& functionName);
+
+ /// Looks up a function handler by name.
+ /// @param functionName The name of the function to look up
+ /// @return The handler if found, nullptr otherwise
+ std::shared_ptr getFunction(
+ const std::string& functionName) const;
+
+ /// Checks if a function is registered.
+ /// @param functionName The name of the function to check
+ /// @return true if the function is registered, false otherwise
+ bool hasFunction(const std::string& functionName) const;
+
+ /// Clears all registered functions.
+ /// Useful for testing scenarios.
+ void clear();
+
+ private:
+ RestFunctionRegistry() = default;
+ ~RestFunctionRegistry() = default;
+
+ // Prevent copying and assignment
+ RestFunctionRegistry(const RestFunctionRegistry&) = delete;
+ RestFunctionRegistry& operator=(const RestFunctionRegistry&) = delete;
+
+ mutable std::mutex mutex_;
+ std::unordered_map>
+ functionHandlers_;
+};
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteDoubleDivHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteDoubleDivHandler.h
new file mode 100644
index 000000000000..b55c4641ab46
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteDoubleDivHandler.h
@@ -0,0 +1,61 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h"
+
+namespace facebook::presto::functions::rest {
+
+class RemoteDoubleDivHandler : public RemoteFunctionRestHandler {
+ public:
+ RemoteDoubleDivHandler() = default;
+
+ velox::RowTypePtr getInputTypes() const override {
+ return velox::ROW({"c0", "c1"}, {velox::DOUBLE(), velox::DOUBLE()});
+ }
+
+ velox::TypePtr getOutputType() const override {
+ return velox::DOUBLE();
+ }
+
+ protected:
+ void compute(
+ const velox::RowVectorPtr& inputVector,
+ const velox::VectorPtr& resultVector,
+ std::string& /*errorMessage*/) override {
+ auto numerator = inputVector->childAt(0)->asFlatVector();
+ auto denominator = inputVector->childAt(1)->asFlatVector();
+ auto outFlat = resultVector->asFlatVector();
+
+ const auto numRows = inputVector->size();
+ for (velox::vector_size_t i = 0; i < numRows; ++i) {
+ // If either input is null, output is null.
+ if (numerator->isNullAt(i) || denominator->isNullAt(i)) {
+ outFlat->setNull(i, true);
+ } else {
+ double numVal = numerator->valueAt(i);
+ double denVal = denominator->valueAt(i);
+ // If denominator is zero, produce a null.
+ if (denVal == 0.0) {
+ outFlat->setNull(i, true);
+ } else {
+ outFlat->set(i, numVal / denVal);
+ }
+ }
+ }
+ }
+};
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteFibonacciHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteFibonacciHandler.h
new file mode 100644
index 000000000000..8de2bd969404
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteFibonacciHandler.h
@@ -0,0 +1,54 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h"
+
+namespace facebook::presto::functions::rest {
+
+class RemoteFibonacciHandler : public RemoteFunctionRestHandler {
+ public:
+ RemoteFibonacciHandler() = default;
+
+ velox::RowTypePtr getInputTypes() const override {
+ return velox::ROW({"c0"}, {velox::BIGINT()});
+ }
+
+ velox::TypePtr getOutputType() const override {
+ return velox::BIGINT();
+ }
+
+ protected:
+ void compute(
+ const velox::RowVectorPtr& inputVector,
+ const velox::VectorPtr& resultVector,
+ std::string& /*errorMessage*/) override {
+ auto numFlat = inputVector->childAt(0)->asFlatVector();
+ auto outFlat = resultVector->asFlatVector();
+
+ const auto numRows = inputVector->size();
+ for (velox::vector_size_t i = 0; i < numRows; ++i) {
+ if (numFlat->isNullAt(i)) {
+ outFlat->setNull(i, true);
+ } else {
+ int64_t num = numFlat->valueAt(i);
+ outFlat->set(i, boost::math::fibonacci(num));
+ }
+ }
+ }
+};
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteInverseCdfHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteInverseCdfHandler.h
new file mode 100644
index 000000000000..bd65d34de129
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteInverseCdfHandler.h
@@ -0,0 +1,76 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h"
+
+namespace facebook::presto::functions::rest {
+namespace {
+inline double inverse_chi_squared_cdf(double p, double nu) {
+ if (p <= 0.0 || p >= 1.0) {
+ throw std::domain_error("inverse_chi_squared_cdf: p must be in (0,1)");
+ }
+ if (nu <= 0.0) {
+ throw std::domain_error(
+ "inverse_chi_squared_cdf: degrees of freedom must be > 0");
+ }
+
+ const boost::math::chi_squared_distribution chi2(nu);
+ double result = boost::math::quantile(chi2, p);
+ return std::round(result * 100.0) / 100.0;
+}
+} // namespace
+
+class RemoteInverseCdfHandler : public RemoteFunctionRestHandler {
+ public:
+ RemoteInverseCdfHandler() = default;
+
+ velox::RowTypePtr getInputTypes() const override {
+ return velox::ROW({"c0", "c1"}, {velox::DOUBLE(), velox::DOUBLE()});
+ }
+
+ velox::TypePtr getOutputType() const override {
+ return velox::DOUBLE();
+ }
+
+ protected:
+ void compute(
+ const velox::RowVectorPtr& inputVector,
+ const velox::VectorPtr& resultVector,
+ std::string& errorMessage) override {
+ auto p = inputVector->childAt(0)->asFlatVector();
+ auto nu = inputVector->childAt(1)->asFlatVector();
+ auto outFlat = resultVector->asFlatVector();
+
+ const auto numRows = inputVector->size();
+ for (velox::vector_size_t i = 0; i < numRows; ++i) {
+ // If either input is null, output is null.
+ if (p->isNullAt(i) || nu->isNullAt(i)) {
+ outFlat->setNull(i, true);
+ } else {
+ try {
+ double pVal = p->valueAt(i);
+ double nuVal = nu->valueAt(i);
+ outFlat->set(i, inverse_chi_squared_cdf(pVal, nuVal));
+ } catch (const std::domain_error& ex) {
+ errorMessage = ex.what();
+ }
+ }
+ }
+ }
+};
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteRemoveCharHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteRemoveCharHandler.h
new file mode 100644
index 000000000000..b79818925812
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteRemoveCharHandler.h
@@ -0,0 +1,64 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h"
+
+namespace facebook::presto::functions::rest {
+
+class RemoteRemoveCharHandler final : public RemoteFunctionRestHandler {
+ public:
+ RemoteRemoveCharHandler() = default;
+
+ velox::RowTypePtr getInputTypes() const override {
+ return velox::ROW({"c0", "c1"}, {velox::VARCHAR(), velox::VARCHAR()});
+ }
+
+ velox::TypePtr getOutputType() const override {
+ return velox::VARCHAR();
+ }
+
+ protected:
+ void compute(
+ const velox::RowVectorPtr& inputVector,
+ const velox::VectorPtr& resultVector,
+ std::string& /*errorMessage*/) override {
+ auto inputFlat = inputVector->childAt(0)->asFlatVector();
+ auto removeFlat =
+ inputVector->childAt(1)->asFlatVector();
+ auto outFlat = resultVector->asFlatVector();
+
+ const auto numRows = inputVector->size();
+
+ for (velox::vector_size_t i = 0; i < numRows; ++i) {
+ if (inputFlat->isNullAt(i) || removeFlat->isNullAt(i)) {
+ outFlat->setNull(i, true);
+ continue;
+ }
+ std::string src(
+ inputFlat->valueAt(i).data(), inputFlat->valueAt(i).size());
+ const auto removeView = removeFlat->valueAt(i);
+ if (removeView.empty()) {
+ outFlat->set(i, velox::StringView(src));
+ continue;
+ }
+ const char ch = removeView.data()[0];
+ src.erase(std::remove(src.begin(), src.end(), ch), src.end());
+ outFlat->set(i, velox::StringView(src));
+ }
+ }
+};
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteStrLenHandler.h b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteStrLenHandler.h
new file mode 100644
index 000000000000..980a50937048
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/tests/server/examples/RemoteStrLenHandler.h
@@ -0,0 +1,52 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "presto_cpp/main/functions/remote/tests/server/RemoteFunctionRestHandler.h"
+
+namespace facebook::presto::functions::rest {
+class RemoteStrLenHandler : public RemoteFunctionRestHandler {
+ public:
+ RemoteStrLenHandler() = default;
+
+ velox::RowTypePtr getInputTypes() const override {
+ return velox::ROW({"c0"}, {velox::VARCHAR()});
+ }
+
+ velox::TypePtr getOutputType() const override {
+ return velox::INTEGER();
+ }
+
+ protected:
+ void compute(
+ const velox::RowVectorPtr& inputVector,
+ const velox::VectorPtr& resultVector,
+ std::string& errorMessage) override {
+ auto inputFlat = inputVector->childAt(0)->asFlatVector();
+ auto outFlat = resultVector->asFlatVector();
+ const auto numRows = inputVector->size();
+
+ for (velox::vector_size_t i = 0; i < numRows; ++i) {
+ if (inputFlat->isNullAt(i)) {
+ outFlat->setNull(i, true);
+ } else {
+ int32_t stringLen = inputFlat->valueAt(i).size();
+ outFlat->set(i, stringLen);
+ }
+ }
+ }
+};
+
+} // namespace facebook::presto::functions::rest
diff --git a/presto-native-execution/presto_cpp/main/functions/remote/utils/ContentTypes.h b/presto-native-execution/presto_cpp/main/functions/remote/utils/ContentTypes.h
new file mode 100644
index 000000000000..4200ea74fb67
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/functions/remote/utils/ContentTypes.h
@@ -0,0 +1,22 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace facebook::presto::functions::remote {
+inline constexpr const char* CONTENT_TYPE_SPARK_UNSAFE_ROW =
+ "application/X-spark-unsafe-row";
+inline constexpr const char* CONTENT_TYPE_PRESTO_PAGE =
+ "application/X-presto-pages";
+} // namespace facebook::presto::functions::remote
diff --git a/presto-native-execution/presto_cpp/main/http/HttpConstants.h b/presto-native-execution/presto_cpp/main/http/HttpConstants.h
index b6e3c24909a9..bc91594efce7 100644
--- a/presto-native-execution/presto_cpp/main/http/HttpConstants.h
+++ b/presto-native-execution/presto_cpp/main/http/HttpConstants.h
@@ -17,6 +17,7 @@ namespace facebook::presto::http {
const uint16_t kHttpOk = 200;
const uint16_t kHttpAccepted = 202;
const uint16_t kHttpNoContent = 204;
+const uint16_t kHttpMultipleChoices = 300;
const uint16_t kHttpBadRequest = 400;
const uint16_t kHttpUnauthorized = 401;
const uint16_t kHttpNotFound = 404;
diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt
index 8f11b6308353..b58b1ac7a7b8 100644
--- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt
+++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt
@@ -16,6 +16,10 @@ target_link_libraries(presto_type_converter velox_presto_type_parser)
add_library(presto_velox_expr_conversion OBJECT PrestoToVeloxExpr.cpp)
target_link_libraries(presto_velox_expr_conversion velox_presto_types velox_vector velox_exception)
+if(PRESTO_ENABLE_REMOTE_FUNCTIONS)
+ target_link_libraries(presto_velox_expr_conversion presto_to_velox_remote_functions)
+endif()
+
add_library(presto_types PrestoToVeloxQueryPlan.cpp VeloxPlanValidator.cpp PrestoToVeloxSplit.cpp)
target_link_libraries(
presto_types
diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp
index b83f12ca0e36..2aecd724cd77 100644
--- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp
+++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp
@@ -14,7 +14,7 @@
#include "presto_cpp/main/types/PrestoToVeloxExpr.h"
#include
-#include
+#include
#include "presto_cpp/main/common/Configs.h"
#include "presto_cpp/main/common/Utils.h"
#include "presto_cpp/presto_protocol/Base64Util.h"
@@ -23,6 +23,9 @@
#include "velox/vector/ComplexVector.h"
#include "velox/vector/ConstantVector.h"
#include "velox/vector/FlatVector.h"
+#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
+#include "presto_cpp/main/functions/remote/PrestoRestFunctionRegistration.h"
+#endif
using namespace facebook::velox::core;
using facebook::velox::TypeKind;
@@ -516,6 +519,19 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr(
return std::make_shared(
returnType, args, getFunctionName(sqlFunctionHandle->functionId));
}
+#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
+ else if (
+ auto restFunctionHandle =
+ std::dynamic_pointer_cast(
+ pexpr.functionHandle)) {
+ auto args = toVeloxExpr(pexpr.arguments);
+ auto returnType = typeParser_->parse(pexpr.returnType);
+
+ functions::remote::rest::registerRestRemoteFunction(*restFunctionHandle);
+ return std::make_shared(
+ returnType, args, getFunctionName(restFunctionHandle->functionId));
+ }
+#endif
VELOX_FAIL("Unsupported function handle: {}", pexpr.functionHandle->_type);
}
diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h
index f63a84ec35ad..bbd0ae54ee81 100644
--- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h
+++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h
@@ -13,7 +13,6 @@
*/
#pragma once
-#include
#include "presto_cpp/main/types/TypeParser.h"
#include "presto_cpp/presto_protocol/core/presto_protocol_core.h"
#include "velox/core/Expressions.h"
diff --git a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt
index 0e1d27d1b18a..25ca93ca91be 100644
--- a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt
+++ b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt
@@ -57,33 +57,22 @@ target_link_libraries(
GTest::gtest
GTest::gtest_main
presto_connectors
- $
- $
- $
+ presto_protocol
+ presto_type_converter
+ presto_types
presto_operators
presto_mutable_configs
presto_type_test_utils
presto_session_properties
velox_core
- velox_dwio_common_exception
- velox_encode
velox_exec
velox_exec_test_lib
velox_functions_prestosql
- velox_functions_lib
+ velox_presto_type_parser
velox_hive_connector
velox_tpch_connector
- velox_hive_partition_function
- velox_presto_serializer
- velox_presto_type_parser
- velox_serialization
velox_type
Boost::filesystem
- ${RE2}
- ${FOLLY_WITH_DEPENDENCIES}
- ${GLOG}
- ${GFLAGS_LIBRARIES}
- pthread
)
set_property(TARGET presto_expressions_test PROPERTY JOB_POOL_LINK presto_link_job_pool)
@@ -133,3 +122,25 @@ target_link_libraries(
GTest::gtest
GTest::gtest_main
)
+
+# RemoteFunctionHandleTest.cpp only contains tests cases related to
+# RestFunctionHandle, therefore it is only enabled when remote functions are
+# enabled.
+if(PRESTO_ENABLE_REMOTE_FUNCTIONS)
+ add_executable(presto_to_rest_function_test RestFunctionHandleTest.cpp)
+
+ add_test(
+ NAME presto_to_rest_function_test
+ COMMAND presto_to_rest_function_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+ )
+
+ target_link_libraries(
+ presto_to_rest_function_test
+ presto_velox_expr_conversion
+ presto_protocol
+ presto_types
+ GTest::gtest
+ GTest::gtest_main
+ )
+endif()
diff --git a/presto-native-execution/presto_cpp/main/types/tests/RestFunctionHandleTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/RestFunctionHandleTest.cpp
new file mode 100644
index 000000000000..f02bd9ba2c39
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/RestFunctionHandleTest.cpp
@@ -0,0 +1,451 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include "presto_cpp/main/common/Configs.h"
+#include "presto_cpp/main/types/PrestoToVeloxExpr.h"
+#include "velox/functions/remote/client/Remote.h"
+#include "velox/functions/remote/server/RemoteFunctionService.h"
+
+using namespace facebook::presto;
+using namespace facebook::velox;
+
+class RestFunctionHandleTest : public ::testing::Test {
+ protected:
+ static void SetUpTestCase() {
+ memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{});
+ }
+
+ void setupCallExpression() {
+ memoryPool_ = memory::MemoryManager::getInstance()->addLeafPool();
+ converter_ =
+ std::make_unique(memoryPool_.get(), &typeParser_);
+
+ expectedMetadata_.serdeFormat = functions::remote::PageFormat::PRESTO_PAGE;
+
+ testExpr_.returnType = "bigint";
+ testExpr_.displayName = "testFunction";
+ auto cexpr = std::make_shared();
+ cexpr->type = "bigint";
+ cexpr->valueBlock.data = "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA";
+ testExpr_.arguments.push_back(cexpr);
+
+ auto cexpr2 = std::make_shared();
+ cexpr2->type = "bigint";
+ cexpr2->valueBlock.data = "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA";
+ testExpr_.arguments.push_back(cexpr2);
+ }
+
+ void SetUp() override {
+ auto restConfig = restSystemConfig();
+ auto systemConfig = SystemConfig::instance();
+ systemConfig->initialize(std::move(restConfig));
+ setupCallExpression();
+ }
+
+ static std::unique_ptr restSystemConfig(
+ const std::unordered_map& configOverride = {}) {
+ std::unordered_map systemConfig{
+ {std::string(SystemConfig::kRemoteFunctionServerSerde),
+ std::string("presto_page")},
+ {std::string(SystemConfig::kRemoteFunctionServerRestURL),
+ std::string("http://localhost:8080")}};
+
+ for (const auto& [configName, configValue] : configOverride) {
+ systemConfig[configName] = configValue;
+ }
+ return std::make_unique(std::move(systemConfig), true);
+ }
+
+ std::shared_ptr functionHandle_;
+ protocol::CallExpression testExpr_;
+ functions::RemoteVectorFunctionMetadata expectedMetadata_;
+ std::shared_ptr memoryPool_;
+ TypeParser typeParser_;
+ std::unique_ptr converter_;
+};
+
+TEST_F(RestFunctionHandleTest, parseRestFunctionHandle) {
+ try {
+ const std::string str = R"(
+ {
+ "@type": "RestFunctionHandle",
+ "functionId": "remote.testSchema.testFunction;BIGINT;BIGINT",
+ "version": "v1",
+ "signature": {
+ "name": "testFunction",
+ "kind": "SCALAR",
+ "returnType": "bigint",
+ "argumentTypes": ["bigint", "bigint"],
+ "typeVariableConstraints": [],
+ "longVariableConstraints": [],
+ "variableArity": false
+ }
+ }
+ )";
+ const json j = json::parse(str);
+ const std::shared_ptr restFunctionHandle = j;
+ testExpr_.functionHandle = restFunctionHandle;
+
+ auto expr = converter_->toVeloxExpr(testExpr_);
+ auto callExpr = std::dynamic_pointer_cast(expr);
+ ASSERT_NE(callExpr, nullptr);
+ EXPECT_EQ(callExpr->name(), "remote.testSchema.testFunction");
+
+ EXPECT_EQ(callExpr->inputs().size(), 2);
+ auto arg0 = std::dynamic_pointer_cast(
+ callExpr->inputs()[0]);
+ auto arg1 = std::dynamic_pointer_cast(
+ callExpr->inputs()[1]);
+ ASSERT_NE(arg0, nullptr);
+ ASSERT_NE(arg1, nullptr);
+ EXPECT_EQ(arg0->type()->kind(), TypeKind::BIGINT);
+ EXPECT_EQ(arg1->type()->kind(), TypeKind::BIGINT);
+
+ } catch (const std::exception& e) {
+ FAIL() << "Exception: " << e.what();
+ }
+}
+
+TEST_F(RestFunctionHandleTest, parseRestFunctionHandleWithDecimalType) {
+ try {
+ const std::string str = R"JSON(
+{
+ "@type": "RestFunctionHandle",
+ "functionId": "remote.testSchema.decimalFunction;decimal(10,2);decimal(10,2)",
+ "version": "v1",
+ "signature": {
+ "name": "decimalFunction",
+ "kind": "SCALAR",
+ "returnType": "decimal(10,2)",
+ "argumentTypes": ["decimal(10,2)", "decimal(10,2)"],
+ "typeVariableConstraints": [],
+ "longVariableConstraints": [],
+ "variableArity": false
+ }
+}
+)JSON";
+
+ const json j = json::parse(str);
+ const std::shared_ptr restFunctionHandle = j;
+
+ protocol::CallExpression callExpr;
+ callExpr.returnType = "decimal(10,2)";
+ callExpr.displayName = "decimalFunction";
+ callExpr.functionHandle = restFunctionHandle;
+
+ auto cexpr = std::make_shared();
+ cexpr->type = "decimal(10,2)";
+ cexpr->valueBlock.data = "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA";
+ callExpr.arguments.push_back(cexpr);
+
+ auto cexpr2 = std::make_shared();
+ cexpr2->type = "decimal(10,2)";
+ cexpr2->valueBlock.data = "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA";
+ callExpr.arguments.push_back(cexpr2);
+
+ auto expr = converter_->toVeloxExpr(callExpr);
+ auto typedCallExpr =
+ std::dynamic_pointer_cast(expr);
+ ASSERT_NE(typedCallExpr, nullptr);
+ EXPECT_EQ(typedCallExpr->name(), "remote.testSchema.decimalFunction");
+ EXPECT_EQ(typedCallExpr->type()->kind(), TypeKind::BIGINT);
+
+ EXPECT_EQ(typedCallExpr->inputs().size(), 2);
+ auto arg0 = std::dynamic_pointer_cast(
+ typedCallExpr->inputs()[0]);
+ auto arg1 = std::dynamic_pointer_cast(
+ typedCallExpr->inputs()[1]);
+ ASSERT_NE(arg0, nullptr);
+ ASSERT_NE(arg1, nullptr);
+ EXPECT_EQ(arg0->type()->kind(), TypeKind::BIGINT);
+ EXPECT_EQ(arg1->type()->kind(), TypeKind::BIGINT);
+
+ } catch (const std::exception& e) {
+ FAIL() << "Exception: " << e.what();
+ }
+}
+
+TEST_F(RestFunctionHandleTest, parseRestFunctionHandleWithArrayType) {
+ try {
+ const std::string str = R"JSON(
+ {
+ "@type": "RestFunctionHandle",
+ "functionId": "remote.testSchema.arrayFunction;array(bigint);array(varchar)",
+ "version": "v1",
+ "signature": {
+ "name": "arrayFunction",
+ "kind": "SCALAR",
+ "returnType": "array(bigint)",
+ "argumentTypes": ["array(bigint)", "array(varchar)"],
+ "typeVariableConstraints": [],
+ "longVariableConstraints": [],
+ "variableArity": false
+ }
+ }
+ )JSON";
+ const json j = json::parse(str);
+ const std::shared_ptr restFunctionHandle = j;
+
+ // Verify the signature parsing
+ ASSERT_NE(restFunctionHandle, nullptr);
+ EXPECT_EQ(
+ restFunctionHandle->functionId,
+ "remote.testSchema.arrayFunction;array(bigint);array(varchar)");
+ EXPECT_EQ(restFunctionHandle->signature.name, "arrayFunction");
+ EXPECT_EQ(restFunctionHandle->signature.returnType, "array(bigint)");
+ EXPECT_EQ(restFunctionHandle->signature.argumentTypes.size(), 2);
+ EXPECT_EQ(restFunctionHandle->signature.argumentTypes[0], "array(bigint)");
+ EXPECT_EQ(restFunctionHandle->signature.argumentTypes[1], "array(varchar)");
+
+ // Verify type parsing
+ auto returnType =
+ typeParser_.parse(restFunctionHandle->signature.returnType);
+ EXPECT_EQ(returnType->kind(), TypeKind::ARRAY);
+ auto returnArrayType =
+ std::dynamic_pointer_cast(returnType);
+ ASSERT_NE(returnArrayType, nullptr);
+ EXPECT_EQ(returnArrayType->elementType()->kind(), TypeKind::BIGINT);
+
+ auto argType0 =
+ typeParser_.parse(restFunctionHandle->signature.argumentTypes[0]);
+ EXPECT_EQ(argType0->kind(), TypeKind::ARRAY);
+ auto argArrayType0 = std::dynamic_pointer_cast(argType0);
+ ASSERT_NE(argArrayType0, nullptr);
+ EXPECT_EQ(argArrayType0->elementType()->kind(), TypeKind::BIGINT);
+
+ auto argType1 =
+ typeParser_.parse(restFunctionHandle->signature.argumentTypes[1]);
+ EXPECT_EQ(argType1->kind(), TypeKind::ARRAY);
+ auto argArrayType1 = std::dynamic_pointer_cast(argType1);
+ ASSERT_NE(argArrayType1, nullptr);
+ EXPECT_EQ(argArrayType1->elementType()->kind(), TypeKind::VARCHAR);
+
+ } catch (const std::exception& e) {
+ FAIL() << "Exception: " << e.what();
+ }
+}
+
+TEST_F(RestFunctionHandleTest, parseRestFunctionHandleWithMapType) {
+ try {
+ const std::string str = R"JSON(
+ {
+ "@type": "RestFunctionHandle",
+ "functionId": "remote.testSchema.mapFunction;map(varchar,bigint);map(bigint,double)",
+ "version": "v1",
+ "signature": {
+ "name": "mapFunction",
+ "kind": "SCALAR",
+ "returnType": "map(varchar,bigint)",
+ "argumentTypes": ["map(varchar,bigint)", "map(bigint,double)"],
+ "typeVariableConstraints": [],
+ "longVariableConstraints": [],
+ "variableArity": false
+ }
+ }
+ )JSON";
+ const json j = json::parse(str);
+ const std::shared_ptr restFunctionHandle = j;
+
+ // Verify the signature parsing
+ ASSERT_NE(restFunctionHandle, nullptr);
+ EXPECT_EQ(
+ restFunctionHandle->functionId,
+ "remote.testSchema.mapFunction;map(varchar,bigint);map(bigint,double)");
+ EXPECT_EQ(restFunctionHandle->signature.name, "mapFunction");
+ EXPECT_EQ(restFunctionHandle->signature.returnType, "map(varchar,bigint)");
+ EXPECT_EQ(restFunctionHandle->signature.argumentTypes.size(), 2);
+ EXPECT_EQ(
+ restFunctionHandle->signature.argumentTypes[0], "map(varchar,bigint)");
+ EXPECT_EQ(
+ restFunctionHandle->signature.argumentTypes[1], "map(bigint,double)");
+
+ // Verify type parsing
+ auto returnType =
+ typeParser_.parse(restFunctionHandle->signature.returnType);
+ EXPECT_EQ(returnType->kind(), TypeKind::MAP);
+ auto returnMapType = std::dynamic_pointer_cast(returnType);
+ ASSERT_NE(returnMapType, nullptr);
+ EXPECT_EQ(returnMapType->keyType()->kind(), TypeKind::VARCHAR);
+ EXPECT_EQ(returnMapType->valueType()->kind(), TypeKind::BIGINT);
+
+ auto argType0 =
+ typeParser_.parse(restFunctionHandle->signature.argumentTypes[0]);
+ EXPECT_EQ(argType0->kind(), TypeKind::MAP);
+ auto argMapType0 = std::dynamic_pointer_cast(argType0);
+ ASSERT_NE(argMapType0, nullptr);
+ EXPECT_EQ(argMapType0->keyType()->kind(), TypeKind::VARCHAR);
+ EXPECT_EQ(argMapType0->valueType()->kind(), TypeKind::BIGINT);
+
+ auto argType1 =
+ typeParser_.parse(restFunctionHandle->signature.argumentTypes[1]);
+ EXPECT_EQ(argType1->kind(), TypeKind::MAP);
+ auto argMapType1 = std::dynamic_pointer_cast(argType1);
+ ASSERT_NE(argMapType1, nullptr);
+ EXPECT_EQ(argMapType1->keyType()->kind(), TypeKind::BIGINT);
+ EXPECT_EQ(argMapType1->valueType()->kind(), TypeKind::DOUBLE);
+
+ } catch (const std::exception& e) {
+ FAIL() << "Exception: " << e.what();
+ }
+}
+
+TEST_F(RestFunctionHandleTest, parseRestFunctionHandleWithRowType) {
+ try {
+ const std::string str = R"JSON(
+ {
+ "@type": "RestFunctionHandle",
+ "functionId": "remote.testSchema.rowFunction;row(bigint,varchar);row(double,boolean)",
+ "version": "v1",
+ "signature": {
+ "name": "rowFunction",
+ "kind": "SCALAR",
+ "returnType": "row(bigint,varchar)",
+ "argumentTypes": ["row(bigint,varchar)", "row(double,boolean)"],
+ "typeVariableConstraints": [],
+ "longVariableConstraints": [],
+ "variableArity": false
+ }
+ }
+ )JSON";
+ const json j = json::parse(str);
+ const std::shared_ptr restFunctionHandle = j;
+
+ // Verify the signature parsing
+ ASSERT_NE(restFunctionHandle, nullptr);
+ EXPECT_EQ(
+ restFunctionHandle->functionId,
+ "remote.testSchema.rowFunction;row(bigint,varchar);row(double,boolean)");
+ EXPECT_EQ(restFunctionHandle->signature.name, "rowFunction");
+ EXPECT_EQ(restFunctionHandle->signature.returnType, "row(bigint,varchar)");
+ EXPECT_EQ(restFunctionHandle->signature.argumentTypes.size(), 2);
+ EXPECT_EQ(
+ restFunctionHandle->signature.argumentTypes[0], "row(bigint,varchar)");
+ EXPECT_EQ(
+ restFunctionHandle->signature.argumentTypes[1], "row(double,boolean)");
+
+ // Verify type parsing
+ auto returnType =
+ typeParser_.parse(restFunctionHandle->signature.returnType);
+ EXPECT_EQ(returnType->kind(), TypeKind::ROW);
+ auto returnRowType = std::dynamic_pointer_cast(returnType);
+ ASSERT_NE(returnRowType, nullptr);
+ EXPECT_EQ(returnRowType->size(), 2);
+ EXPECT_EQ(returnRowType->childAt(0)->kind(), TypeKind::BIGINT);
+ EXPECT_EQ(returnRowType->childAt(1)->kind(), TypeKind::VARCHAR);
+
+ auto argType0 =
+ typeParser_.parse(restFunctionHandle->signature.argumentTypes[0]);
+ EXPECT_EQ(argType0->kind(), TypeKind::ROW);
+ auto argRowType0 = std::dynamic_pointer_cast(argType0);
+ ASSERT_NE(argRowType0, nullptr);
+ EXPECT_EQ(argRowType0->size(), 2);
+ EXPECT_EQ(argRowType0->childAt(0)->kind(), TypeKind::BIGINT);
+ EXPECT_EQ(argRowType0->childAt(1)->kind(), TypeKind::VARCHAR);
+
+ auto argType1 =
+ typeParser_.parse(restFunctionHandle->signature.argumentTypes[1]);
+ EXPECT_EQ(argType1->kind(), TypeKind::ROW);
+ auto argRowType1 = std::dynamic_pointer_cast(argType1);
+ ASSERT_NE(argRowType1, nullptr);
+ EXPECT_EQ(argRowType1->size(), 2);
+ EXPECT_EQ(argRowType1->childAt(0)->kind(), TypeKind::DOUBLE);
+ EXPECT_EQ(argRowType1->childAt(1)->kind(), TypeKind::BOOLEAN);
+
+ } catch (const std::exception& e) {
+ FAIL() << "Exception: " << e.what();
+ }
+}
+
+TEST_F(RestFunctionHandleTest, parseRestFunctionHandleWithNestedComplexTypes) {
+ try {
+ const std::string str = R"JSON(
+ {
+ "@type": "RestFunctionHandle",
+ "functionId": "remote.testSchema.nestedFunction;array(map(varchar,bigint));row(array(decimal(10,2)),map(bigint,varchar))",
+ "version": "v1",
+ "signature": {
+ "name": "nestedFunction",
+ "kind": "SCALAR",
+ "returnType": "map(varchar,array(bigint))",
+ "argumentTypes": ["array(map(varchar,bigint))", "row(array(decimal(10,2)),map(bigint,varchar))"],
+ "typeVariableConstraints": [],
+ "longVariableConstraints": [],
+ "variableArity": false
+ }
+ }
+ )JSON";
+ const json j = json::parse(str);
+ const std::shared_ptr restFunctionHandle = j;
+
+ // Verify the signature parsing
+ ASSERT_NE(restFunctionHandle, nullptr);
+ EXPECT_EQ(restFunctionHandle->signature.name, "nestedFunction");
+ EXPECT_EQ(
+ restFunctionHandle->signature.returnType, "map(varchar,array(bigint))");
+ EXPECT_EQ(restFunctionHandle->signature.argumentTypes.size(), 2);
+
+ // Verify return type: map(varchar,array(bigint))
+ auto returnType =
+ typeParser_.parse(restFunctionHandle->signature.returnType);
+ EXPECT_EQ(returnType->kind(), TypeKind::MAP);
+ auto returnMapType = std::dynamic_pointer_cast(returnType);
+ ASSERT_NE(returnMapType, nullptr);
+ EXPECT_EQ(returnMapType->keyType()->kind(), TypeKind::VARCHAR);
+ EXPECT_EQ(returnMapType->valueType()->kind(), TypeKind::ARRAY);
+ auto valueArrayType =
+ std::dynamic_pointer_cast(returnMapType->valueType());
+ ASSERT_NE(valueArrayType, nullptr);
+ EXPECT_EQ(valueArrayType->elementType()->kind(), TypeKind::BIGINT);
+
+ // Verify arg0 type: array(map(varchar,bigint))
+ auto argType0 =
+ typeParser_.parse(restFunctionHandle->signature.argumentTypes[0]);
+ EXPECT_EQ(argType0->kind(), TypeKind::ARRAY);
+ auto argArrayType = std::dynamic_pointer_cast(argType0);
+ ASSERT_NE(argArrayType, nullptr);
+ EXPECT_EQ(argArrayType->elementType()->kind(), TypeKind::MAP);
+ auto elementMapType =
+ std::dynamic_pointer_cast(argArrayType->elementType());
+ ASSERT_NE(elementMapType, nullptr);
+ EXPECT_EQ(elementMapType->keyType()->kind(), TypeKind::VARCHAR);
+ EXPECT_EQ(elementMapType->valueType()->kind(), TypeKind::BIGINT);
+
+ // Verify arg1 type: row(array(decimal(10,2)),map(bigint,varchar))
+ auto argType1 =
+ typeParser_.parse(restFunctionHandle->signature.argumentTypes[1]);
+ EXPECT_EQ(argType1->kind(), TypeKind::ROW);
+ auto argRowType = std::dynamic_pointer_cast(argType1);
+ ASSERT_NE(argRowType, nullptr);
+ EXPECT_EQ(argRowType->size(), 2);
+
+ // First child: array(decimal(10,2))
+ EXPECT_EQ(argRowType->childAt(0)->kind(), TypeKind::ARRAY);
+ auto childArrayType =
+ std::dynamic_pointer_cast(argRowType->childAt(0));
+ ASSERT_NE(childArrayType, nullptr);
+ EXPECT_EQ(childArrayType->elementType()->kind(), TypeKind::BIGINT);
+
+ // Second child: map(bigint,varchar)
+ EXPECT_EQ(argRowType->childAt(1)->kind(), TypeKind::MAP);
+ auto childMapType =
+ std::dynamic_pointer_cast(argRowType->childAt(1));
+ ASSERT_NE(childMapType, nullptr);
+ EXPECT_EQ(childMapType->keyType()->kind(), TypeKind::BIGINT);
+ EXPECT_EQ(childMapType->valueType()->kind(), TypeKind::VARCHAR);
+
+ } catch (const std::exception& e) {
+ FAIL() << "Exception: " << e.what();
+ }
+}
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java
index 99da6faace55..b07e6112c7b5 100644
--- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java
+++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java
@@ -28,6 +28,7 @@
import com.facebook.presto.sql.planner.NodePartitioningManager;
import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager;
import com.facebook.presto.testing.MaterializedResult;
+import com.facebook.presto.testing.MaterializedRow;
import com.facebook.presto.testing.QueryRunner;
import com.facebook.presto.testing.TestingAccessControlManager;
import com.facebook.presto.transaction.TransactionManager;
@@ -38,6 +39,7 @@
import java.io.IOException;
import java.sql.Connection;
+import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
@@ -50,46 +52,57 @@
import java.util.logging.Logger;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
-import static java.sql.DriverManager.getConnection;
public class ContainerQueryRunner
implements QueryRunner
{
- private static final Network network = Network.newNetwork();
- private static final String PRESTO_COORDINATOR_IMAGE = System.getProperty("coordinatorImage", "presto-coordinator:latest");
- private static final String PRESTO_WORKER_IMAGE = System.getProperty("workerImage", "presto-worker:latest");
- private static final String CONTAINER_TIMEOUT = System.getProperty("containerTimeout", "120");
- private static final String CLUSTER_SHUTDOWN_TIMEOUT = System.getProperty("clusterShutDownTimeout", "10");
- private static final String BASE_DIR = System.getProperty("user.dir");
- private static final int DEFAULT_COORDINATOR_PORT = 8080;
- private static final String TPCH_CATALOG = "tpch";
- private static final String TINY_SCHEMA = "tiny";
- private static final int DEFAULT_NUMBER_OF_WORKERS = 4;
- private static final Logger logger = Logger.getLogger(ContainerQueryRunner.class.getName());
- private final GenericContainer> coordinator;
- private final List> workers = new ArrayList<>();
- private final int coordinatorPort;
- private final String catalog;
- private final String schema;
- private final int numberOfWorkers;
- private Connection connection;
+ protected static final Network network = Network.newNetwork();
+ protected static final String PRESTO_COORDINATOR_IMAGE = System.getProperty("coordinatorImage", "presto-coordinator:latest");
+ protected static final String PRESTO_WORKER_IMAGE = System.getProperty("workerImage", "presto-worker:latest");
+ protected static final String CONTAINER_TIMEOUT = System.getProperty("containerTimeout", "120");
+ protected static final String CLUSTER_SHUTDOWN_TIMEOUT = System.getProperty("clusterShutDownTimeout", "10");
+ protected static final String BASE_DIR = System.getProperty("user.dir");
+ protected static final int DEFAULT_COORDINATOR_PORT = 8080;
+ protected static final int DEFAULT_FUNCTION_SERVER_PORT = 1122;
+ protected static final String TPCH_CATALOG = "tpch";
+ protected static final String TINY_SCHEMA = "tiny";
+ protected static final int DEFAULT_NUMBER_OF_WORKERS = 4;
+
+ protected static final Logger logger = Logger.getLogger(ContainerQueryRunner.class.getName());
+
+ protected final GenericContainer> coordinator;
+ protected final List> workers = new ArrayList<>();
+ protected final int coordinatorPort;
+ protected final String catalog;
+ protected final String schema;
+ protected GenericContainer> functionServer;
+ protected int functionServerPort;
+ protected boolean enableFunctionServer;
+ protected Connection connection;
public ContainerQueryRunner()
throws InterruptedException, IOException
{
- this(DEFAULT_COORDINATOR_PORT, TPCH_CATALOG, TINY_SCHEMA, DEFAULT_NUMBER_OF_WORKERS);
+ this(DEFAULT_COORDINATOR_PORT, TPCH_CATALOG, TINY_SCHEMA, DEFAULT_NUMBER_OF_WORKERS, DEFAULT_FUNCTION_SERVER_PORT, false);
}
- public ContainerQueryRunner(int coordinatorPort, String catalog, String schema, int numberOfWorkers)
+ public ContainerQueryRunner(int coordinatorPort, String catalog, String schema, int numberOfWorkers, int functionServerPort, boolean enableFunctionServer)
throws InterruptedException, IOException
{
this.coordinatorPort = coordinatorPort;
this.catalog = catalog;
this.schema = schema;
- this.numberOfWorkers = numberOfWorkers;
+ this.functionServerPort = functionServerPort;
+ this.enableFunctionServer = enableFunctionServer;
+
+ // Start function server first if enabled
+ if (enableFunctionServer) {
+ this.functionServer = createFunctionServer();
+ this.functionServer.start();
+ logger.info("Presto function server is deployed at http://" + functionServer.getHost() + ":" + functionServer.getMappedPort(functionServerPort));
+ }
- // The container details can be added as properties in VM options for testing in IntelliJ.
- coordinator = createCoordinator();
+ this.coordinator = createCoordinator();
for (int i = 0; i < numberOfWorkers; i++) {
workers.add(createNativeWorker(7777 + i, "native-worker-" + i));
}
@@ -107,10 +120,10 @@ public ContainerQueryRunner(int coordinatorPort, String catalog, String schema,
coordinator.getMappedPort(coordinatorPort),
catalog,
schema,
- "timeZoneId=UTC");
+ enableFunctionServer ? "timeZoneId=UTC&sessionProperties=remote_functions_enabled:true" : "timeZoneId=UTC");
try {
- connection = getConnection(url, "test", null);
+ this.connection = DriverManager.getConnection(url, "test", null);
}
catch (SQLException e) {
throw new RuntimeException(e);
@@ -118,12 +131,16 @@ public ContainerQueryRunner(int coordinatorPort, String catalog, String schema,
// Delete the temporary files once the containers are started.
ContainerQueryRunnerUtils.deleteDirectory(BASE_DIR + "/testcontainers/coordinator");
- for (int i = 0; i < numberOfWorkers; i++) {
- ContainerQueryRunnerUtils.deleteDirectory(BASE_DIR + "/testcontainers/native-worker-" + i);
+ for (GenericContainer> worker : workers) {
+ String alias = worker.getNetworkAliases().get(1);
+ ContainerQueryRunnerUtils.deleteDirectory(BASE_DIR + "/testcontainers/" + alias);
+ }
+ if (enableFunctionServer) {
+ ContainerQueryRunnerUtils.deleteDirectory(BASE_DIR + "/testcontainers/function-server");
}
}
- private GenericContainer> createCoordinator()
+ protected GenericContainer> createCoordinator()
throws IOException
{
ContainerQueryRunnerUtils.createCoordinatorTpchProperties();
@@ -132,22 +149,25 @@ private GenericContainer> createCoordinator()
ContainerQueryRunnerUtils.createCoordinatorJvmConfig();
ContainerQueryRunnerUtils.createCoordinatorLogProperties();
ContainerQueryRunnerUtils.createCoordinatorNodeProperties();
- ContainerQueryRunnerUtils.createCoordinatorEntryPointScript();
+ ContainerQueryRunnerUtils.createCoordinatorEntryPointScript(); // Never run function server in coordinator
+ if (enableFunctionServer) {
+ ContainerQueryRunnerUtils.createRestRemoteProperties(functionServerPort);
+ }
return new GenericContainer<>(PRESTO_COORDINATOR_IMAGE)
- .withExposedPorts(coordinatorPort)
.withNetwork(network)
.withNetworkAliases("presto-coordinator")
.withCopyFileToContainer(MountableFile.forHostPath(BASE_DIR + "/testcontainers/coordinator/etc"), "/opt/presto-server/etc")
.withCopyFileToContainer(MountableFile.forHostPath(BASE_DIR + "/testcontainers/coordinator/entrypoint.sh"), "/opt/entrypoint.sh")
.waitingFor(Wait.forLogMessage(".*======== SERVER STARTED ========.*", 1))
- .withStartupTimeout(Duration.ofSeconds(Long.parseLong(CONTAINER_TIMEOUT)));
+ .withStartupTimeout(Duration.ofSeconds(Long.parseLong(CONTAINER_TIMEOUT)))
+ .withExposedPorts(coordinatorPort);
}
- private GenericContainer> createNativeWorker(int port, String nodeId)
+ protected GenericContainer> createNativeWorker(int port, String nodeId)
throws IOException
{
- ContainerQueryRunnerUtils.createNativeWorkerConfigProperties(coordinatorPort, nodeId);
+ ContainerQueryRunnerUtils.createNativeWorkerConfigPropertiesWithFunctionServer(coordinatorPort, functionServerPort, nodeId);
ContainerQueryRunnerUtils.createNativeWorkerTpchProperties(nodeId);
ContainerQueryRunnerUtils.createNativeWorkerEntryPointScript(nodeId);
ContainerQueryRunnerUtils.createNativeWorkerNodeProperties(nodeId);
@@ -160,6 +180,23 @@ private GenericContainer> createNativeWorker(int port, String nodeId)
.waitingFor(Wait.forLogMessage(".*Announcement succeeded: HTTP 202.*", 1));
}
+ protected GenericContainer> createFunctionServer()
+ throws IOException
+ {
+ ContainerQueryRunnerUtils.createFunctionServerConfigProperties(functionServerPort);
+ ContainerQueryRunnerUtils.createFunctionServerEntryPointScript();
+
+ // Reuse the coordinator image since it already contains the function server jar
+ return new GenericContainer<>(PRESTO_COORDINATOR_IMAGE)
+ .withNetwork(network)
+ .withNetworkAliases("presto-remote-function-server")
+ .withCopyFileToContainer(MountableFile.forHostPath(BASE_DIR + "/testcontainers/function-server/etc"), "/opt/function-server/etc")
+ .withCopyFileToContainer(MountableFile.forHostPath(BASE_DIR + "/testcontainers/function-server/entrypoint.sh"), "/opt/entrypoint.sh")
+ .waitingFor(Wait.forLogMessage(".*======== REMOTE FUNCTION SERVER STARTED at: .*", 1))
+ .withStartupTimeout(Duration.ofSeconds(Long.parseLong(CONTAINER_TIMEOUT)))
+ .withExposedPorts(functionServerPort);
+ }
+
@Override
public void close()
{
@@ -171,6 +208,9 @@ public void close()
}
coordinator.stop();
workers.forEach(GenericContainer::stop);
+ if (functionServer != null) {
+ functionServer.stop();
+ }
}
@Override
@@ -248,7 +288,27 @@ public MaterializedResult execute(String sql)
@Override
public MaterializedResult execute(Session session, String sql, List extends Type> resultTypes)
{
- throw new UnsupportedOperationException();
+ // Added logic similar to H2QueryRunner.
+ try {
+ Statement statement = connection.createStatement();
+ ResultSet resultSet = statement.executeQuery(sql);
+ MaterializedResult rawResult = ContainerQueryRunnerUtils.toMaterializedResult(resultSet);
+
+ // Coerce the raw result to the requested resultTypes
+ List coercedRows = new ArrayList<>();
+ for (MaterializedRow row : rawResult.getMaterializedRows()) {
+ List