Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/issue 5262 support to create remote udf in sql #802

Merged
Merged
9 changes: 5 additions & 4 deletions src/Interpreters/InterpreterCreateFunctionQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ BlockIO InterpreterCreateFunctionQuery::execute()
bool replace_if_exists = create_function_query.or_replace;

/// proton: starts. Handle javascript UDF
if (create_function_query.isJavaScript())
return handleJavaScriptUDF(throw_if_exists, replace_if_exists);
if (create_function_query.isJavaScript() || create_function_query.isRemote())
return handleUDF(throw_if_exists, replace_if_exists);
/// proton: ends

UserDefinedSQLFunctionFactory::instance().registerFunction(current_context, function_name, query_ptr, throw_if_exists, replace_if_exists);
Expand All @@ -64,16 +64,17 @@ BlockIO InterpreterCreateFunctionQuery::execute()
}

/// proton: starts
BlockIO InterpreterCreateFunctionQuery::handleJavaScriptUDF(bool throw_if_exists, bool replace_if_exists)
BlockIO InterpreterCreateFunctionQuery::handleUDF(bool throw_if_exists, bool replace_if_exists)
{
ASTCreateFunctionQuery & create = query_ptr->as<ASTCreateFunctionQuery &>();
assert(create.isJavaScript());
assert(create.isJavaScript() || create.isRemote());

const auto func_name = create.getFunctionName();
Poco::JSON::Object::Ptr func = create.toJSON();
UserDefinedFunctionFactory::instance().registerFunction(getContext(), func_name, func, throw_if_exists, replace_if_exists);

return {};
}

/// proton: ends
}
2 changes: 1 addition & 1 deletion src/Interpreters/InterpreterCreateFunctionQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class InterpreterCreateFunctionQuery : public IInterpreter, WithMutableContext
ASTPtr query_ptr;

/// proton: starts
BlockIO handleJavaScriptUDF(bool throw_if_exists, bool replace_if_exists);
BlockIO handleUDF(bool throw_if_exists, bool replace_if_exists);
/// proton: ends
};

Expand Down
57 changes: 52 additions & 5 deletions src/Parsers/ASTCreateFunctionQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
#include <Parsers/ASTFunction.h>

/// proton: starts
#include <Parsers/formatAST.h>
#include <Parsers/ASTNameTypePair.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTNameTypePair.h>
#include <Parsers/formatAST.h>

#include <cassert>
#include <boost/algorithm/string/case_conv.hpp>
/// proton: ends


Expand Down Expand Up @@ -35,8 +38,11 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I
settings.ostr << "OR REPLACE ";

/// proton: starts
bool is_remote = isRemote();
if (is_aggregation)
settings.ostr << "AGGREGATE FUNCTION ";
else if (is_remote)
settings.ostr << "REMOTE FUNCTION ";
else
settings.ostr << "FUNCTION ";
/// proton: ends
Expand All @@ -50,7 +56,7 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I

/// proton: starts
bool is_javascript_func = isJavaScript();
if (is_javascript_func)
if (is_javascript_func || is_remote)
{
/// arguments
arguments->formatImpl(settings, state, frame);
Expand All @@ -63,6 +69,21 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I

formatOnCluster(settings);

/// proton: starts
if (is_remote)
{
settings.ostr << fmt::format("\nURL '{}'\n", function_core->as<ASTLiteral>()->value.safeGet<String>());
auto auth_method
= !function_core->children.empty() ? function_core->children[0]->as<ASTLiteral>()->value.safeGet<String>() : "none";
settings.ostr << fmt::format("AUTH_METHOD '{}'\n", auth_method);
if (auth_method != "none")
{
settings.ostr << fmt::format("AUTH_HEADER '{}'\n", function_core->children[1]->as<ASTLiteral>()->value.safeGet<String>());
settings.ostr << fmt::format("AUTH_KEY '{}'\n", function_core->children[2]->as<ASTLiteral>()->value.safeGet<String>());
}
return;
}
/// proton: ends
settings.ostr << (settings.hilite ? hilite_keyword : "") << " AS " << (settings.hilite ? hilite_none : "");

/// proton: starts. Do not format the source of JavaScript UDF
Expand All @@ -89,7 +110,8 @@ Poco::JSON::Object::Ptr ASTCreateFunctionQuery::toJSON() const
Poco::JSON::Object::Ptr func = new Poco::JSON::Object(Poco::JSON_PRESERVE_KEY_ORDER);
Poco::JSON::Object::Ptr inner_func = new Poco::JSON::Object(Poco::JSON_PRESERVE_KEY_ORDER);
inner_func->set("name", getFunctionName());
if (!isJavaScript())
bool is_remote = isRemote();
if (!isJavaScript() && !isRemote())
{
WriteBufferFromOwnString source_buf;
formatAST(*function_core, source_buf, false);
Expand All @@ -116,7 +138,9 @@ Poco::JSON::Object::Ptr ASTCreateFunctionQuery::toJSON() const
inner_func->set("arguments", json_args);

/// type
inner_func->set("type", "javascript");
auto type = lang;
boost::to_lower(type);
inner_func->set("type", type);

/// is_aggregation
inner_func->set("is_aggregation", is_aggregation);
Expand All @@ -126,6 +150,29 @@ Poco::JSON::Object::Ptr ASTCreateFunctionQuery::toJSON() const
formatAST(*return_type, return_buf, false);
inner_func->set("return_type", return_buf.str());

/// remote function
if (is_remote)
{
assert(function_core != nullptr);
inner_func->set("url", function_core->as<ASTLiteral>()->value.safeGet<String>());
// auth
if (!function_core->children.empty())
{
auto auth_method = function_core->children[0]->as<ASTLiteral>()->value.safeGet<String>();
inner_func->set("auth_method", auth_method);
if (auth_method == "auth_header")
{
Poco::JSON::Object::Ptr auth_context = new Poco::JSON::Object();
auth_context->set("key_name", function_core->children[1]->as<ASTLiteral>()->value.safeGet<String>());
auth_context->set("key_value", function_core->children[2]->as<ASTLiteral>()->value.safeGet<String>());
inner_func->set("auth_context", auth_context);
}
}
func->set("function", inner_func);
/// Remote function don't have source, return early.
return func;
}

/// source
ASTLiteral * js_src = function_core->as<ASTLiteral>();
inner_func->set("source", js_src->value.safeGet<String>());
Expand Down
3 changes: 3 additions & 0 deletions src/Parsers/ASTCreateFunctionQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class ASTCreateFunctionQuery : public IAST, public ASTQueryWithOnCluster

/// If it is a JavaScript UDF
bool isJavaScript() const noexcept { return lang == "JavaScript"; }

/// If it is a JavaScript UDF
bool isRemote() const noexcept { return lang == "Remote"; }
/// proton: ends
};

Expand Down
71 changes: 65 additions & 6 deletions src/Parsers/ParserCreateFunctionQuery.cpp
chhtimeplus marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,24 @@
#include <Parsers/ExpressionListParsers.h>

/// proton: starts
#include <Parsers/ASTLiteral.h>
#include <Parsers/ParserCreateQuery.h>
#include <Parsers/Streaming/ParserArguments.h>

#include <Poco/JSON/Object.h>
/// proton: ends


namespace DB
{

chhtimeplus marked this conversation as resolved.
Show resolved Hide resolved
/// proton: starts
namespace ErrorCodes
{
extern const int AGGREGATE_FUNCTION_NOT_APPLICABLE;
extern const int UNKNOWN_FUNCTION;
}
/// proton: ends
bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected & expected, [[ maybe_unused ]] bool hint)
{
ParserKeyword s_create("CREATE");
Expand All @@ -25,6 +35,16 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
ParserKeyword s_aggr_function("AGGREGATE FUNCTION");
ParserKeyword s_returns("RETURNS");
ParserKeyword s_javascript_type("LANGUAGE JAVASCRIPT");
ParserKeyword s_remote("REMOTE FUNCTION");
ParserKeyword s_url("URL");
ParserKeyword s_auth_method("AUTH_METHOD");
ParserKeyword s_auth_header("AUTH_HEADER");
ParserKeyword s_auth_key("AUTH_KEY");
ParserLiteral value;
ASTPtr url;
ASTPtr auth_method;
ASTPtr auth_header;
ASTPtr auth_key;
ParserArguments arguments_p;
ParserDataType return_p;
ParserStringLiteral js_src_p;
Expand All @@ -46,6 +66,7 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
bool is_aggregation = false;
bool is_javascript_func = false;
bool is_new_syntax = false;
bool is_remote = false;
/// proton: ends

String cluster_str;
Expand All @@ -61,10 +82,12 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
/// proton: starts
if (!s_function.ignore(pos, expected))
{
if(!s_aggr_function.ignore(pos, expected))
if (s_aggr_function.ignore(pos, expected))
is_aggregation = true;
else if (s_remote.ignore(pos, expected))
is_remote = true;
else
return false;

is_aggregation = true;
}
/// proton: ends

Expand All @@ -88,13 +111,13 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
/// proton: starts
if (is_new_syntax && s_returns.ignore(pos, expected))
{
if(!return_p.parse(pos, return_type, expected))
if (!return_p.parse(pos, return_type, expected))
return false;

if (s_javascript_type.ignore(pos, expected))
is_javascript_func = true;

if (!s_as.ignore(pos, expected))
if (!is_remote && !s_as.ignore(pos, expected))
return false;

/// Parse source code and function_core will be 'ASTLiteral'
Expand All @@ -110,6 +133,42 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
if (!lambda_p.parse(pos, function_core, expected))
return false;
}
if (is_remote)
{
if (is_aggregation)
{
throw Exception("Remote udf can not be an aggregate function", ErrorCodes::AGGREGATE_FUNCTION_NOT_APPLICABLE);
}
if (!s_url.ignore(pos, expected))
return false;
if (!value.parse(pos, url, expected))
return false;
if (s_auth_method.ignore(pos, expected))
{
if (!value.parse(pos, auth_method, expected))
chhtimeplus marked this conversation as resolved.
Show resolved Hide resolved
return false;
auto method_str = auth_method->as<ASTLiteral>()->value.safeGet<String>();
url->children.push_back(std::move(auth_method));
if (method_str == "auth_header")
{
if (!s_auth_header.ignore(pos, expected))
return false;
if (!value.parse(pos, auth_header, expected))
return false;
if (!s_auth_key.ignore(pos, expected))
return false;
if (!value.parse(pos, auth_key, expected))
return false;
url->children.push_back(std::move(auth_header));
url->children.push_back(std::move(auth_key));
}
else if (method_str != "none")
{
throw Exception("AUTH_METHOD must be 'none' or 'auth_header'", ErrorCodes::UNKNOWN_FUNCTION);
}
}
function_core = std::move(url);
}
/// proton: ends

auto create_function_query = std::make_shared<ASTCreateFunctionQuery>();
Expand All @@ -127,7 +186,7 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp

/// proton: starts
create_function_query->is_aggregation = is_aggregation;
create_function_query->lang = is_javascript_func ? "JavaScript" : "SQL";
create_function_query->lang = is_javascript_func ? "JavaScript" : is_remote ? "Remote" : "SQL";
create_function_query->arguments = std::move(arguments);
create_function_query->return_type = std::move(return_type);
/// proton: ends
Expand Down
Loading
Loading