Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -520,24 +520,32 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr(
auto returnType = typeParser_->parse(pexpr.returnType);
return std::make_shared<CallTypedExpr>(
returnType, args, getFunctionName(signature));
}

} else if (
auto sqlFunctionHandle =
// Parse args and returnType once for all remaining branches
auto args = toVeloxExpr(pexpr.arguments);
auto returnType = typeParser_->parse(pexpr.returnType);

if (auto sqlFunctionHandle =
std::dynamic_pointer_cast<protocol::SqlFunctionHandle>(
pexpr.functionHandle)) {
auto args = toVeloxExpr(pexpr.arguments);
auto returnType = typeParser_->parse(pexpr.returnType);
return std::make_shared<CallTypedExpr>(
returnType, args, getFunctionName(sqlFunctionHandle->functionId));
}

else if (
auto nativeFunctionHandle =
std::dynamic_pointer_cast<protocol::NativeFunctionHandle>(
pexpr.functionHandle)) {
auto signature = nativeFunctionHandle->signature;
return std::make_shared<CallTypedExpr>(
returnType, args, getFunctionName(signature));
}
#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
else if (
auto restFunctionHandle =
std::dynamic_pointer_cast<protocol::RestFunctionHandle>(
pexpr.functionHandle)) {
auto args = toVeloxExpr(pexpr.arguments);
auto returnType = typeParser_->parse(pexpr.returnType);

functions::remote::rest::PrestoRestFunctionRegistration::getInstance()
.registerFunction(*restFunctionHandle);
return std::make_shared<CallTypedExpr>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,21 @@ bool useCachedHashTable(const protocol::PlanNode& node) {
return false;
}

const protocol::Signature* getSignatureFromFunctionHandle(
const std::shared_ptr<protocol::FunctionHandle>& functionHandle) {
if (const auto builtin =
std::dynamic_pointer_cast<protocol::BuiltInFunctionHandle>(
functionHandle)) {
return &builtin->signature;
} else if (
const auto native =
std::dynamic_pointer_cast<protocol::NativeFunctionHandle>(
functionHandle)) {
return &native->signature;
}
return nullptr;
}

std::vector<std::string> getNames(const protocol::Assignments& assignments) {
std::vector<std::string> names;
names.reserve(assignments.assignments.size());
Expand Down Expand Up @@ -933,12 +948,10 @@ void VeloxQueryPlanConverterBase::toAggregations(
aggregate.call = std::dynamic_pointer_cast<const core::CallTypedExpr>(
exprConverter_.toVeloxExpr(prestoAggregation.call));

if (const auto builtin =
std::dynamic_pointer_cast<protocol::BuiltInFunctionHandle>(
prestoAggregation.functionHandle)) {
const auto& signature = builtin->signature;
aggregate.rawInputTypes.reserve(signature.argumentTypes.size());
for (const auto& argumentType : signature.argumentTypes) {
if (const auto signature =
getSignatureFromFunctionHandle(prestoAggregation.functionHandle)) {
aggregate.rawInputTypes.reserve(signature->argumentTypes.size());
for (const auto& argumentType : signature->argumentTypes) {
aggregate.rawInputTypes.push_back(
stringToType(argumentType, typeParser_));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
#include "presto_cpp/main/types/VeloxToPrestoExpr.h"
#include <boost/algorithm/string.hpp>
#include "presto_cpp/main/common/Utils.h"
#include "presto_cpp/main/types/PrestoToVeloxExpr.h"
#include "velox/core/ITypedExpr.h"
#include "velox/expression/ExprConstants.h"
Expand Down Expand Up @@ -96,6 +97,30 @@ const std::unordered_map<std::string, std::string>& veloxToPrestoOperatorMap() {
}
return veloxToPrestoOperatorMap;
}

// If the function name prefix starts from "presto.default", then it is a built
// in function handle. Otherwise, it is a native function handle.
std::shared_ptr<protocol::FunctionHandle> getFunctionHandle(
const std::string& name,
const protocol::Signature& signature) {
static constexpr char const* kStatic = "$static";
static constexpr char const* kNativeFunctionHandle = "native";
static constexpr char const* builtInCatalog = "presto";
static constexpr char const* builtInSchema = "default";

const auto parts = util::getFunctionNameParts(name);
if ((parts[0] == builtInCatalog) && (parts[1] == builtInSchema)) {
auto handle = std::make_shared<protocol::BuiltInFunctionHandle>();
handle->_type = kStatic;
handle->signature = signature;
return handle;
} else {
auto handle = std::make_shared<protocol::NativeFunctionHandle>();
handle->_type = kNativeFunctionHandle;
handle->signature = signature;
return handle;
}
}
} // namespace

std::string VeloxToPrestoExprConverter::getValueBlock(
Expand Down Expand Up @@ -281,10 +306,7 @@ CallExpressionPtr VeloxToPrestoExprConverter::getCallExpression(
signature.argumentTypes = argumentTypes;
signature.variableArity = false;

protocol::BuiltInFunctionHandle builtInFunctionHandle;
builtInFunctionHandle._type = kStatic;
builtInFunctionHandle.signature = signature;
result["functionHandle"] = builtInFunctionHandle;
result["functionHandle"] = getFunctionHandle(exprName, signature);
result["returnType"] = getTypeSignature(expr->type());
result["arguments"] = json::array();
for (const auto& exprInput : exprInputs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ target_link_libraries(
GTest::gtest_main
)

add_executable(presto_to_velox_query_plan_test PrestoToVeloxQueryPlanTest.cpp)
add_executable(
presto_to_velox_query_plan_test
PrestoToVeloxQueryPlanTest.cpp
NativeFunctionHandleTest.cpp
)

add_test(
NAME presto_to_velox_query_plan_test
Expand Down
Loading
Loading