Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,17 @@
using facebook::velox::functions::remote::PageFormat;

namespace facebook::presto::functions::remote::rest {
namespace {
// Returns the serialization/deserialization format used by the remote function
// server. The format is determined by the system configuration value
// "remoteFunctionServerSerde". Supported formats:
// - "presto_page": Uses Presto page format.
// - "spark_unsafe_row": Uses Spark unsafe row format.
// @return PageFormat enum value corresponding to the configured serde format.
// @throws VeloxException if the configured format is unknown.
PageFormat getSerdeFormat() {

PrestoRestFunctionRegistration::PrestoRestFunctionRegistration()
: kRemoteFunctionServerRestURL_(
SystemConfig::instance()->remoteFunctionServerRestURL()) {}

// Returns the shared registration object. It is a function-local static, so
// it is constructed the first time this accessor is called.
PrestoRestFunctionRegistration& PrestoRestFunctionRegistration::getInstance() {
  static PrestoRestFunctionRegistration registration;
  return registration;
}

PageFormat PrestoRestFunctionRegistration::getSerdeFormat() {
static const auto serdeFormat =
SystemConfig::instance()->remoteFunctionServerSerde();
if (serdeFormat == "presto_page") {
Expand All @@ -46,31 +48,22 @@ PageFormat getSerdeFormat() {
}
}

// Encodes a string for safe inclusion in a URL by escaping non-alphanumeric
// characters using percent-encoding. Alphanumeric characters and '-', '_', '.',
// '~' are left unchanged. All other characters are replaced with '%' followed
// by their two-digit hexadecimal value.
// @param value The input string to encode.
// @return The URL-encoded string.
std::string urlEncode(const std::string& value) {
// Percent-encodes `value` for inclusion in a URL. Only characters contained in
// boost::urls::unreserved_chars are passed through unescaped.
std::string PrestoRestFunctionRegistration::urlEncode(
    const std::string& value) {
  const auto& passThroughChars = boost::urls::unreserved_chars;
  return boost::urls::encode(value, passThroughChars);
}

std::string getFunctionName(const protocol::SqlFunctionId& functionId) {
// Extracts the name portion of a SQL function ID, i.e. everything that
// precedes the first ';'.
// Example ID: "namespace.schema.function;TYPE;TYPE".
std::string PrestoRestFunctionRegistration::getFunctionName(
    const protocol::SqlFunctionId& functionId) {
  const auto separatorPos = functionId.find(';');
  if (separatorPos == std::string::npos) {
    // No ';' at all (possible when the function takes no arguments): the
    // whole ID is the name.
    return functionId;
  }
  return functionId.substr(0, separatorPos);
}

// Constructs a Velox function signature from a Presto function signature. This
// function translates type variable constraints, integer variable constraints,
// return type, argument types, and variable arity from the Presto signature to
// the corresponding Velox signature builder.
// @param prestoSignature The Presto function signature to convert.
// @return A pointer to the constructed Velox function signature.
velox::exec::FunctionSignaturePtr buildVeloxSignatureFromPrestoSignature(
velox::exec::FunctionSignaturePtr
PrestoRestFunctionRegistration::buildVeloxSignatureFromPrestoSignature(
const protocol::Signature& prestoSignature) {
velox::exec::FunctionSignatureBuilder signatureBuilder;

Expand All @@ -93,48 +86,48 @@ velox::exec::FunctionSignaturePtr buildVeloxSignatureFromPrestoSignature(
return signatureBuilder.build();
}

} // namespace
// Picks the URL of the server that executes the given function: the handle's
// own execution endpoint when one is present and non-empty, otherwise the URL
// captured from the system configuration at construction.
std::string PrestoRestFunctionRegistration::getRemoteFunctionServerUrl(
    const protocol::RestFunctionHandle& restFunctionHandle) const {
  const auto& endpoint = restFunctionHandle.executionEndpoint;
  const bool hasEndpointOverride = endpoint && !endpoint->empty();
  return hasEndpointOverride ? *endpoint : kRemoteFunctionServerRestURL_;
}

void registerRestRemoteFunction(
void PrestoRestFunctionRegistration::registerFunction(
const protocol::RestFunctionHandle& restFunctionHandle) {
static std::mutex registrationMutex;
static std::unordered_map<std::string, std::string> registeredFunctionHandles;
static std::unordered_map<std::string, functions::rest::RestRemoteClientPtr>
restClient;
static const std::string remoteFunctionServerRestURL =
SystemConfig::instance()->remoteFunctionServerRestURL();

const std::string functionId = restFunctionHandle.functionId;

const std::string remoteFunctionServerRestURL =
getRemoteFunctionServerUrl(restFunctionHandle);
json functionHandleJson;
to_json(functionHandleJson, restFunctionHandle);
functionHandleJson["url"] = remoteFunctionServerRestURL;
const std::string serializedFunctionHandle = functionHandleJson.dump();

// Skip registration when this function ID is already registered with an
// identical serialized handle. The lookup is done under the registration
// mutex.
{
std::lock_guard<std::mutex> lock(registrationMutex);
auto it = registeredFunctionHandles.find(functionId);
if (it != registeredFunctionHandles.end() &&
std::lock_guard<std::mutex> lock(registrationMutex_);
auto it = registeredFunctionHandles_.find(functionId);
if (it != registeredFunctionHandles_.end() &&
it->second == serializedFunctionHandle) {
return;
}
}

// Get or create shared RestRemoteClient for this server URL
functions::rest::RestRemoteClientPtr remoteClient;
RestRemoteClientPtr remoteClient;
{
std::lock_guard<std::mutex> lock(registrationMutex);
auto clientIt = restClient.find(remoteFunctionServerRestURL);
if (clientIt == restClient.end()) {
restClient[remoteFunctionServerRestURL] =
std::make_shared<functions::rest::RestRemoteClient>(
remoteFunctionServerRestURL);
std::lock_guard<std::mutex> lock(registrationMutex_);
auto clientIt = restClients_.find(remoteFunctionServerRestURL);
if (clientIt == restClients_.end()) {
restClients_[remoteFunctionServerRestURL] =
std::make_shared<RestRemoteClient>(remoteFunctionServerRestURL);
}
remoteClient = restClient[remoteFunctionServerRestURL];
remoteClient = restClients_[remoteFunctionServerRestURL];
}

functions::rest::VeloxRemoteFunctionMetadata metadata;
VeloxRemoteFunctionMetadata metadata;

// Extract function name parts using the utility function
const std::string functionName =
Expand All @@ -158,16 +151,16 @@ void registerRestRemoteFunction(
std::vector<velox::exec::FunctionSignaturePtr> veloxSignatures = {
veloxSignature};

functions::rest::registerVeloxRemoteFunction(
registerVeloxRemoteFunction(
getFunctionName(restFunctionHandle.functionId),
veloxSignatures,
metadata,
remoteClient);

// Update registration map
{
std::lock_guard<std::mutex> lock(registrationMutex);
registeredFunctionHandles[functionId] = serializedFunctionHandle;
std::lock_guard<std::mutex> lock(registrationMutex_);
registeredFunctionHandles_[functionId] = serializedFunctionHandle;
}
}
} // namespace facebook::presto::functions::remote::rest
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,102 @@

#pragma once

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

#include "presto_cpp/main/functions/remote/RestRemoteFunction.h"
#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h"
#include "presto_cpp/presto_protocol/presto_protocol.h"
#include "velox/expression/FunctionSignature.h"

namespace facebook::presto::functions::remote::rest {

void registerRestRemoteFunction(
const protocol::RestFunctionHandle& restFunctionHandle);
/// Singleton that registers REST-based remote functions with Velox.
/// Registration bookkeeping (the handles already registered and the client
/// created for each server URL) is kept in this object and guarded by a mutex.
class PrestoRestFunctionRegistration {
 public:
  /// Copying is disabled: there is exactly one instance, see getInstance().
  PrestoRestFunctionRegistration(const PrestoRestFunctionRegistration&) =
      delete;
  PrestoRestFunctionRegistration& operator=(
      const PrestoRestFunctionRegistration&) = delete;

  /// Accessor for the singleton.
  /// @return Reference to the one shared registration object.
  static PrestoRestFunctionRegistration& getInstance();

  /// Registers the REST remote function described by the handle with Velox.
  /// Safe to call from multiple threads; duplicate registrations are handled.
  /// @param restFunctionHandle Presto REST function handle carrying the
  /// function's metadata, signature and location information.
  void registerFunction(const protocol::RestFunctionHandle& restFunctionHandle);

 private:
  // Construction is restricted to getInstance().
  PrestoRestFunctionRegistration();

  // Determines which remote function server URL to use for a handle.
  // @param restFunctionHandle The Presto REST function handle, which may
  // carry its own execution endpoint.
  // @return The resolved remote function server URL.
  std::string getRemoteFunctionServerUrl(
      const protocol::RestFunctionHandle& restFunctionHandle) const;

  // Strips the type parameters from a function ID.
  // @param functionId The SQL function ID.
  // @return The function name without type parameters.
  static std::string getFunctionName(const protocol::SqlFunctionId& functionId);

  // Makes a string safe for inclusion in a URL.
  // @param value The input string to encode.
  // @return The URL-encoded string.
  static std::string urlEncode(const std::string& value);

  // Reports the serialization/deserialization format used by the remote
  // function server.
  // @return PageFormat enum value corresponding to the configured serde
  // format.
  static velox::functions::remote::PageFormat getSerdeFormat();

  // Translates a Presto function signature into a Velox function signature.
  // @param prestoSignature The Presto function signature to convert.
  // @return A pointer to the constructed Velox function signature.
  static velox::exec::FunctionSignaturePtr
  buildVeloxSignatureFromPrestoSignature(
      const protocol::Signature& prestoSignature);

  // Guards the two maps below during registration.
  std::mutex registrationMutex_;

  // Registered function ID -> serialized function handle.
  std::unordered_map<std::string, std::string> registeredFunctionHandles_;

  // REST server URL -> client instance used to talk to that server.
  std::unordered_map<std::string, RestRemoteClientPtr> restClients_;

  // Base URL of the remote function server REST API.
  const std::string kRemoteFunctionServerRestURL_;

  // Tests that are granted access to the private members above.
  VELOX_FRIEND_TEST(
      PrestoRestFunctionRegistrationTest,
      getRemoteFunctionServerUrlWithExecutionEndpoint);
  VELOX_FRIEND_TEST(
      PrestoRestFunctionRegistrationTest,
      getRemoteFunctionServerUrlWithEmptyExecutionEndpoint);
  VELOX_FRIEND_TEST(
      PrestoRestFunctionRegistrationTest,
      getRemoteFunctionServerUrlWithoutExecutionEndpoint);
  VELOX_FRIEND_TEST(
      PrestoRestFunctionRegistrationTest,
      getRemoteFunctionServerUrlConsistency);
  VELOX_FRIEND_TEST(
      PrestoRestFunctionRegistrationTest,
      getRemoteFunctionServerUrlWithDifferentProtocols);
  VELOX_FRIEND_TEST(
      PrestoRestFunctionRegistrationTest,
      getRemoteFunctionServerUrlWithComplexUrls);
};
} // namespace facebook::presto::functions::remote::rest
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "velox/functions/remote/client/RemoteVectorFunction.h"

using namespace facebook::velox;
namespace facebook::presto::functions::rest {
namespace facebook::presto::functions::remote::rest {
namespace {

class RestRemoteFunction : public velox::functions::RemoteVectorFunction {
Expand Down Expand Up @@ -100,4 +100,4 @@ void registerVeloxRemoteFunction(
overwrite);
}

} // namespace facebook::presto::functions::rest
} // namespace facebook::presto::functions::remote::rest
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "presto_cpp/main/functions/remote/client/RestRemoteClient.h"
#include "velox/functions/remote/client/RemoteVectorFunction.h"

namespace facebook::presto::functions::rest {
namespace facebook::presto::functions::remote::rest {

struct VeloxRemoteFunctionMetadata
: public velox::functions::RemoteVectorFunctionMetadata {
Expand All @@ -32,4 +32,4 @@ void registerVeloxRemoteFunction(
RestRemoteClientPtr restClient,
bool overwrite = true);

} // namespace facebook::presto::functions::rest
} // namespace facebook::presto::functions::remote::rest
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

using namespace facebook::velox;

namespace facebook::presto::functions::rest {
namespace facebook::presto::functions::remote::rest {
namespace {
inline std::string getContentType(velox::functions::remote::PageFormat fmt) {
return fmt == velox::functions::remote::PageFormat::SPARK_UNSAFE_ROW
Expand Down Expand Up @@ -110,4 +110,4 @@ std::unique_ptr<folly::IOBuf> RestRemoteClient::invokeFunction(
return nullptr;
}

} // namespace facebook::presto::functions::rest
} // namespace facebook::presto::functions::remote::rest
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "presto_cpp/main/http/HttpClient.h"
#include "velox/functions/remote/if/gen-cpp2/RemoteFunction_types.h"

namespace facebook::presto::functions::rest {
namespace facebook::presto::functions::remote::rest {

class RestRemoteClient {
public:
Expand Down Expand Up @@ -50,4 +50,4 @@ class RestRemoteClient {

using RestRemoteClientPtr = std::shared_ptr<RestRemoteClient>;

} // namespace facebook::presto::functions::rest
} // namespace facebook::presto::functions::remote::rest
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,16 @@ target_link_libraries(
GTest::gtest
GTest::gtest_main
)

# Test executable built from PrestoRestFunctionRegistrationTest.cpp.
add_executable(
  presto_rest_function_registration_test
  PrestoRestFunctionRegistrationTest.cpp
)

# Use the NAME/COMMAND signature: it resolves the COMMAND from the executable
# target, whereas the legacy add_test(<name> <command>) form does not support
# target names on its command line.
add_test(
  NAME presto_rest_function_registration_test
  COMMAND presto_rest_function_registration_test
)

target_link_libraries(
  presto_rest_function_registration_test
  presto_to_velox_remote_functions
  presto_functions_remote
  presto_protocol
  GTest::gtest
  GTest::gtest_main
)
Loading
Loading