Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into feature/issue-794-…
Browse files Browse the repository at this point in the history
…pin-ec2-type
  • Loading branch information
yokofly committed Jul 18, 2024
2 parents 9c1a601 + 9d4978b commit f34344c
Show file tree
Hide file tree
Showing 7 changed files with 371 additions and 16 deletions.
9 changes: 5 additions & 4 deletions src/Interpreters/InterpreterCreateFunctionQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ BlockIO InterpreterCreateFunctionQuery::execute()
bool replace_if_exists = create_function_query.or_replace;

/// proton: starts. Handle javascript UDF
if (create_function_query.isJavaScript())
return handleJavaScriptUDF(throw_if_exists, replace_if_exists);
if (create_function_query.isJavaScript() || create_function_query.isRemote())
return handleUDF(throw_if_exists, replace_if_exists);
/// proton: ends

UserDefinedSQLFunctionFactory::instance().registerFunction(current_context, function_name, query_ptr, throw_if_exists, replace_if_exists);
Expand All @@ -64,16 +64,17 @@ BlockIO InterpreterCreateFunctionQuery::execute()
}

/// proton: starts
BlockIO InterpreterCreateFunctionQuery::handleJavaScriptUDF(bool throw_if_exists, bool replace_if_exists)
BlockIO InterpreterCreateFunctionQuery::handleUDF(bool throw_if_exists, bool replace_if_exists)
{
ASTCreateFunctionQuery & create = query_ptr->as<ASTCreateFunctionQuery &>();
assert(create.isJavaScript());
assert(create.isJavaScript() || create.isRemote());

const auto func_name = create.getFunctionName();
Poco::JSON::Object::Ptr func = create.toJSON();
UserDefinedFunctionFactory::instance().registerFunction(getContext(), func_name, func, throw_if_exists, replace_if_exists);

return {};
}

/// proton: ends
}
2 changes: 1 addition & 1 deletion src/Interpreters/InterpreterCreateFunctionQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class InterpreterCreateFunctionQuery : public IInterpreter, WithMutableContext
ASTPtr query_ptr;

/// proton: starts
BlockIO handleJavaScriptUDF(bool throw_if_exists, bool replace_if_exists);
BlockIO handleUDF(bool throw_if_exists, bool replace_if_exists);
/// proton: ends
};

Expand Down
57 changes: 52 additions & 5 deletions src/Parsers/ASTCreateFunctionQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
#include <Parsers/ASTFunction.h>

/// proton: starts
#include <Parsers/formatAST.h>
#include <Parsers/ASTNameTypePair.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTNameTypePair.h>
#include <Parsers/formatAST.h>

#include <cassert>
#include <boost/algorithm/string/case_conv.hpp>
/// proton: ends


Expand Down Expand Up @@ -35,8 +38,11 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I
settings.ostr << "OR REPLACE ";

/// proton: starts
bool is_remote = isRemote();
if (is_aggregation)
settings.ostr << "AGGREGATE FUNCTION ";
else if (is_remote)
settings.ostr << "REMOTE FUNCTION ";
else
settings.ostr << "FUNCTION ";
/// proton: ends
Expand All @@ -50,7 +56,7 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I

/// proton: starts
bool is_javascript_func = isJavaScript();
if (is_javascript_func)
if (is_javascript_func || is_remote)
{
/// arguments
arguments->formatImpl(settings, state, frame);
Expand All @@ -63,6 +69,21 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I

formatOnCluster(settings);

/// proton: starts
if (is_remote)
{
settings.ostr << fmt::format("\nURL '{}'\n", function_core->as<ASTLiteral>()->value.safeGet<String>());
auto auth_method
= !function_core->children.empty() ? function_core->children[0]->as<ASTLiteral>()->value.safeGet<String>() : "none";
settings.ostr << fmt::format("AUTH_METHOD '{}'\n", auth_method);
if (auth_method != "none")
{
settings.ostr << fmt::format("AUTH_HEADER '{}'\n", function_core->children[1]->as<ASTLiteral>()->value.safeGet<String>());
settings.ostr << fmt::format("AUTH_KEY '{}'\n", function_core->children[2]->as<ASTLiteral>()->value.safeGet<String>());
}
return;
}
/// proton: ends
settings.ostr << (settings.hilite ? hilite_keyword : "") << " AS " << (settings.hilite ? hilite_none : "");

/// proton: starts. Do not format the source of JavaScript UDF
Expand All @@ -89,7 +110,8 @@ Poco::JSON::Object::Ptr ASTCreateFunctionQuery::toJSON() const
Poco::JSON::Object::Ptr func = new Poco::JSON::Object(Poco::JSON_PRESERVE_KEY_ORDER);
Poco::JSON::Object::Ptr inner_func = new Poco::JSON::Object(Poco::JSON_PRESERVE_KEY_ORDER);
inner_func->set("name", getFunctionName());
if (!isJavaScript())
bool is_remote = isRemote();
if (!isJavaScript() && !isRemote())
{
WriteBufferFromOwnString source_buf;
formatAST(*function_core, source_buf, false);
Expand All @@ -116,7 +138,9 @@ Poco::JSON::Object::Ptr ASTCreateFunctionQuery::toJSON() const
inner_func->set("arguments", json_args);

/// type
inner_func->set("type", "javascript");
auto type = lang;
boost::to_lower(type);
inner_func->set("type", type);

/// is_aggregation
inner_func->set("is_aggregation", is_aggregation);
Expand All @@ -126,6 +150,29 @@ Poco::JSON::Object::Ptr ASTCreateFunctionQuery::toJSON() const
formatAST(*return_type, return_buf, false);
inner_func->set("return_type", return_buf.str());

/// remote function
if (is_remote)
{
assert(function_core != nullptr);
inner_func->set("url", function_core->as<ASTLiteral>()->value.safeGet<String>());
// auth
if (!function_core->children.empty())
{
auto auth_method = function_core->children[0]->as<ASTLiteral>()->value.safeGet<String>();
inner_func->set("auth_method", auth_method);
if (auth_method == "auth_header")
{
Poco::JSON::Object::Ptr auth_context = new Poco::JSON::Object();
auth_context->set("key_name", function_core->children[1]->as<ASTLiteral>()->value.safeGet<String>());
auth_context->set("key_value", function_core->children[2]->as<ASTLiteral>()->value.safeGet<String>());
inner_func->set("auth_context", auth_context);
}
}
func->set("function", inner_func);
/// Remote function don't have source, return early.
return func;
}

/// source
ASTLiteral * js_src = function_core->as<ASTLiteral>();
inner_func->set("source", js_src->value.safeGet<String>());
Expand Down
3 changes: 3 additions & 0 deletions src/Parsers/ASTCreateFunctionQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class ASTCreateFunctionQuery : public IAST, public ASTQueryWithOnCluster

/// If it is a JavaScript UDF
bool isJavaScript() const noexcept { return lang == "JavaScript"; }

/// If it is a JavaScript UDF
bool isRemote() const noexcept { return lang == "Remote"; }
/// proton: ends
};

Expand Down
71 changes: 65 additions & 6 deletions src/Parsers/ParserCreateFunctionQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,24 @@
#include <Parsers/ExpressionListParsers.h>

/// proton: starts
#include <Parsers/ASTLiteral.h>
#include <Parsers/ParserCreateQuery.h>
#include <Parsers/Streaming/ParserArguments.h>

#include <Poco/JSON/Object.h>
/// proton: ends


namespace DB
{

/// proton: starts
namespace ErrorCodes
{
extern const int AGGREGATE_FUNCTION_NOT_APPLICABLE;
extern const int UNKNOWN_FUNCTION;
}
/// proton: ends
bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected & expected, [[ maybe_unused ]] bool hint)
{
ParserKeyword s_create("CREATE");
Expand All @@ -25,6 +35,16 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
ParserKeyword s_aggr_function("AGGREGATE FUNCTION");
ParserKeyword s_returns("RETURNS");
ParserKeyword s_javascript_type("LANGUAGE JAVASCRIPT");
ParserKeyword s_remote("REMOTE FUNCTION");
ParserKeyword s_url("URL");
ParserKeyword s_auth_method("AUTH_METHOD");
ParserKeyword s_auth_header("AUTH_HEADER");
ParserKeyword s_auth_key("AUTH_KEY");
ParserLiteral value;
ASTPtr url;
ASTPtr auth_method;
ASTPtr auth_header;
ASTPtr auth_key;
ParserArguments arguments_p;
ParserDataType return_p;
ParserStringLiteral js_src_p;
Expand All @@ -46,6 +66,7 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
bool is_aggregation = false;
bool is_javascript_func = false;
bool is_new_syntax = false;
bool is_remote = false;
/// proton: ends

String cluster_str;
Expand All @@ -61,10 +82,12 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
/// proton: starts
if (!s_function.ignore(pos, expected))
{
if(!s_aggr_function.ignore(pos, expected))
if (s_aggr_function.ignore(pos, expected))
is_aggregation = true;
else if (s_remote.ignore(pos, expected))
is_remote = true;
else
return false;

is_aggregation = true;
}
/// proton: ends

Expand All @@ -88,13 +111,13 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
/// proton: starts
if (is_new_syntax && s_returns.ignore(pos, expected))
{
if(!return_p.parse(pos, return_type, expected))
if (!return_p.parse(pos, return_type, expected))
return false;

if (s_javascript_type.ignore(pos, expected))
is_javascript_func = true;

if (!s_as.ignore(pos, expected))
if (!is_remote && !s_as.ignore(pos, expected))
return false;

/// Parse source code and function_core will be 'ASTLiteral'
Expand All @@ -110,6 +133,42 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
if (!lambda_p.parse(pos, function_core, expected))
return false;
}
if (is_remote)
{
if (is_aggregation)
{
throw Exception("Remote udf can not be an aggregate function", ErrorCodes::AGGREGATE_FUNCTION_NOT_APPLICABLE);
}
if (!s_url.ignore(pos, expected))
return false;
if (!value.parse(pos, url, expected))
return false;
if (s_auth_method.ignore(pos, expected))
{
if (!value.parse(pos, auth_method, expected))
return false;
auto method_str = auth_method->as<ASTLiteral>()->value.safeGet<String>();
url->children.push_back(std::move(auth_method));
if (method_str == "auth_header")
{
if (!s_auth_header.ignore(pos, expected))
return false;
if (!value.parse(pos, auth_header, expected))
return false;
if (!s_auth_key.ignore(pos, expected))
return false;
if (!value.parse(pos, auth_key, expected))
return false;
url->children.push_back(std::move(auth_header));
url->children.push_back(std::move(auth_key));
}
else if (method_str != "none")
{
throw Exception("AUTH_METHOD must be 'none' or 'auth_header'", ErrorCodes::UNKNOWN_FUNCTION);
}
}
function_core = std::move(url);
}
/// proton: ends

auto create_function_query = std::make_shared<ASTCreateFunctionQuery>();
Expand All @@ -127,7 +186,7 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp

/// proton: starts
create_function_query->is_aggregation = is_aggregation;
create_function_query->lang = is_javascript_func ? "JavaScript" : "SQL";
create_function_query->lang = is_javascript_func ? "JavaScript" : is_remote ? "Remote" : "SQL";
create_function_query->arguments = std::move(arguments);
create_function_query->return_type = std::move(return_type);
/// proton: ends
Expand Down
Loading

0 comments on commit f34344c

Please sign in to comment.