Skip to content

Commit

Permalink
refactor parser and test, fix some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chhtimeplus committed Jul 15, 2024
1 parent 8640e88 commit 750f022
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 66 deletions.
48 changes: 30 additions & 18 deletions src/Parsers/ASTCreateFunctionQuery.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include <Common/quoteString.h>
#include <IO/Operators.h>
#include <Parsers/ASTCreateFunctionQuery.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTFunction.h>
#include <Common/quoteString.h>

/// proton: starts
#include <Parsers/formatAST.h>
#include <Parsers/ASTNameTypePair.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTNameTypePair.h>
#include <Parsers/formatAST.h>

#include <boost/algorithm/string/case_conv.hpp>
/// proton: ends
Expand All @@ -26,7 +26,7 @@ ASTPtr ASTCreateFunctionQuery::clone() const

res->function_core = function_core->clone();
res->children.push_back(res->function_core);
res->payload = payload;
res->remote_func_settings = remote_func_settings;
return res;
}

Expand All @@ -38,8 +38,11 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I
settings.ostr << "OR REPLACE ";

/// proton: starts
bool is_remote = isRemote();
if (is_aggregation)
settings.ostr << "AGGREGATE FUNCTION ";
else if (is_remote)
settings.ostr << "REMOTE FUNCTION ";
else
settings.ostr << "FUNCTION ";
/// proton: ends
Expand All @@ -53,7 +56,6 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I

/// proton: starts
bool is_javascript_func = isJavaScript();
bool is_remote = isRemote();
if (is_javascript_func || is_remote)
{
/// arguments
Expand All @@ -70,11 +72,17 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I
/// proton: starts
if (is_remote)
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << fmt::format("\nTYPE Remote \n") << (settings.hilite ? hilite_none : "");
settings.ostr << fmt::format("URL '{}'\n", payload->get("AUTH_METHOD").toString());
settings.ostr << fmt::format("AUTH_METHOD '{}'\n", payload->has("AUTH_METHOD") ? payload->get("AUTH_METHOD").toString() : "none");
settings.ostr << fmt::format("AUTH_HEADER '{}'\n", payload->has("AUTH_HEADER") ? payload->get("AUTH_HEADER").toString() : "none");
settings.ostr << fmt::format("AUTH_KEY '{}'\n", payload->has("AUTH_KEY") ? payload->get("AUTH_KEY").toString() : "none");
settings.ostr << fmt::format("\nURL '{}'\n", remote_func_settings->get("URL").toString());
auto auth_method = remote_func_settings->has("AUTH_METHOD") ? remote_func_settings->get("AUTH_METHOD").toString() : "none";
settings.ostr << fmt::format("AUTH_METHOD '{}'\n", auth_method);
if (auth_method != "none")
{
settings.ostr << fmt::format(
"AUTH_HEADER '{}'\n",
remote_func_settings->has("AUTH_HEADER") ? remote_func_settings->get("AUTH_HEADER").toString() : "none");
settings.ostr << fmt::format(
"AUTH_KEY '{}'\n", remote_func_settings->has("AUTH_KEY") ? remote_func_settings->get("AUTH_KEY").toString() : "none");
}
return;
}
/// proton: ends
Expand Down Expand Up @@ -144,18 +152,22 @@ Poco::JSON::Object::Ptr ASTCreateFunctionQuery::toJSON() const
formatAST(*return_type, return_buf, false);
inner_func->set("return_type", return_buf.str());

/// remote functio
/// remote function
if (is_remote)
{
inner_func->set("url", payload->get("URL").toString());
inner_func->set("url", remote_func_settings->get("URL").toString());
// auth
if (payload->has("AUTH_METHOD"))
if (remote_func_settings->has("AUTH_METHOD"))
{
inner_func->set("auth_method", payload->get("AUTH_METHOD").toString());
Poco::JSON::Object::Ptr auth_context = new Poco::JSON::Object();
auth_context->set("key_name", payload->get("AUTH_HEADER").toString());
auth_context->set("key_value", payload->get("AUTH_KEY").toString());
inner_func->set("auth_context", auth_context);
auto auth_method(remote_func_settings->get("AUTH_METHOD").toString());
inner_func->set("auth_method", auth_method);
if (auth_method == "auth_header")
{
Poco::JSON::Object::Ptr auth_context = new Poco::JSON::Object();
auth_context->set("key_name", remote_func_settings->get("AUTH_HEADER").toString());
auth_context->set("key_value", remote_func_settings->get("AUTH_KEY").toString());
inner_func->set("auth_context", auth_context);
}
}
func->set("function", inner_func);
/// Remote function don't have source, return early.
Expand Down
4 changes: 2 additions & 2 deletions src/Parsers/ASTCreateFunctionQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class ASTCreateFunctionQuery : public IAST, public ASTQueryWithOnCluster
public:
ASTPtr function_name;
ASTPtr function_core;
Poco::JSON::Object::Ptr payload;
Poco::JSON::Object::Ptr remote_func_settings;

bool or_replace = false;
bool if_not_exists = false;
Expand Down Expand Up @@ -44,7 +44,7 @@ class ASTCreateFunctionQuery : public IAST, public ASTQueryWithOnCluster
bool isJavaScript() const noexcept { return lang == "JavaScript"; }

/// If it is a JavaScript UDF
bool isRemote() const noexcept { return lang == "remote"; }
bool isRemote() const noexcept { return lang == "Remote"; }
/// proton: ends
};

Expand Down
76 changes: 44 additions & 32 deletions src/Parsers/ParserCreateFunctionQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
#include <Parsers/ExpressionListParsers.h>

/// proton: starts
#include <Parsers/ASTLiteral.h>
#include <Parsers/ParserCreateQuery.h>
#include <Parsers/Streaming/ParserArguments.h>
#include <Parsers/ASTLiteral.h>

#include <Poco/JSON/Object.h>
/// proton: ends
Expand All @@ -21,9 +21,10 @@ namespace DB

namespace ErrorCodes
{
extern const int AGGREGATE_FUNCTION_NOT_APPLICABLE;
extern const int AGGREGATE_FUNCTION_NOT_APPLICABLE;
extern const int UNKNOWN_FUNCTION;
}
bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected & expected, [[ maybe_unused ]] bool hint)
bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected & expected, [[maybe_unused]] bool hint)
{
ParserKeyword s_create("CREATE");
ParserKeyword s_function("FUNCTION");
Expand All @@ -32,7 +33,7 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
ParserKeyword s_aggr_function("AGGREGATE FUNCTION");
ParserKeyword s_returns("RETURNS");
ParserKeyword s_javascript_type("LANGUAGE JAVASCRIPT");
ParserKeyword s_type_remote("TYPE Remote");
ParserKeyword s_remote("REMOTE FUNCTION");
ParserKeyword s_url("URL");
ParserKeyword s_auth_method("AUTH_METHOD");
ParserKeyword s_auth_header("AUTH_HEADER");
Expand Down Expand Up @@ -79,10 +80,12 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
/// proton: starts
if (!s_function.ignore(pos, expected))
{
if(!s_aggr_function.ignore(pos, expected))
if (s_aggr_function.ignore(pos, expected))
is_aggregation = true;
else if (s_remote.ignore(pos, expected))
is_remote = true;
else
return false;

is_aggregation = true;
}
/// proton: ends

Expand All @@ -106,14 +109,11 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
/// proton: starts
if (is_new_syntax && s_returns.ignore(pos, expected))
{
if(!return_p.parse(pos, return_type, expected))
if (!return_p.parse(pos, return_type, expected))
return false;

if (s_javascript_type.ignore(pos, expected))
is_javascript_func = true;
else if (s_type_remote.ignore(pos,expected)){
is_remote = true;
}

if (!is_remote && !s_as.ignore(pos, expected))
return false;
Expand All @@ -131,32 +131,44 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp
if (!lambda_p.parse(pos, function_core, expected))
return false;
}
Poco::JSON::Object::Ptr payload = new Poco::JSON::Object();
if (is_remote){
if (is_aggregation){
throw Exception("Remote udf can not be an aggregate function",ErrorCodes::AGGREGATE_FUNCTION_NOT_APPLICABLE);
Poco::JSON::Object::Ptr remote_func_settings = new Poco::JSON::Object();
if (is_remote)
{
if (is_aggregation)
{
throw Exception("Remote udf can not be an aggregate function", ErrorCodes::AGGREGATE_FUNCTION_NOT_APPLICABLE);
}
if (!s_url.ignore(pos,expected))
if (!s_url.ignore(pos, expected))
return false;
if (!value.parse(pos, url, expected))
return false;
function_core = std::make_shared<ASTLiteral>(Field());
payload->set("URL", url->as<ASTLiteral>()->value.safeGet<String>());
if (s_auth_method.ignore(pos,expected)){
remote_func_settings->set("URL", url->as<ASTLiteral>()->value.safeGet<String>());
if (s_auth_method.ignore(pos, expected))
{
if (!value.parse(pos, auth_method, expected))
return false;
if (!s_auth_header.ignore(pos, expected))
return false;
if (!value.parse(pos, auth_header, expected))
return false;
if (!s_auth_key.ignore(pos, expected))
return false;
if (!value.parse(pos, auth_key, expected))
return false;
payload->set("AUTH_METHOD", auth_method->as<ASTLiteral>()->value.safeGet<String>());
payload->set("AUTH_HEADER", auth_header->as<ASTLiteral>()->value.safeGet<String>());
payload->set("AUTH_KEY", auth_key->as<ASTLiteral>()->value.safeGet<String>());
auto method_str = auth_method->as<ASTLiteral>()->value.safeGet<String>();
if (method_str == "auth_header")
{
if (!s_auth_header.ignore(pos, expected))
return false;
if (!value.parse(pos, auth_header, expected))
return false;
if (!s_auth_key.ignore(pos, expected))
return false;
if (!value.parse(pos, auth_key, expected))
return false;
remote_func_settings->set("AUTH_HEADER", auth_header->as<ASTLiteral>()->value.safeGet<String>());
remote_func_settings->set("AUTH_KEY", auth_key->as<ASTLiteral>()->value.safeGet<String>());
}
else if (method_str != "none")
{
throw Exception("Auth_method must be 'none' or 'auth_header'", ErrorCodes::UNKNOWN_FUNCTION);
}
remote_func_settings->set("AUTH_METHOD", method_str);
}

function_core = std::make_shared<ASTLiteral>(Field());
}
/// proton: ends

Expand All @@ -175,10 +187,10 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp

/// proton: starts
create_function_query->is_aggregation = is_aggregation;
create_function_query->lang = is_javascript_func ? "JavaScript" : is_remote ? "remote" : "SQL";
create_function_query->lang = is_javascript_func ? "JavaScript" : is_remote ? "Remote" : "SQL";
create_function_query->arguments = std::move(arguments);
create_function_query->return_type = std::move(return_type);
create_function_query->payload = payload;
create_function_query->remote_func_settings = remote_func_settings;
/// proton: ends

return true;
Expand Down
119 changes: 119 additions & 0 deletions src/Parsers/tests/gtest_create_remote_func_parser.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#include <IO/WriteBufferFromString.h>
#include <Parsers/ASTCreateFunctionQuery.h>
#include <Parsers/ASTFunctionWithKeyValueArguments.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/DumpASTNode.h>
#include <Parsers/ParserCreateFunctionQuery.h>
#include <Parsers/ParserCreateQuery.h>
#include <Parsers/ParserDropQuery.h>
#include <Parsers/TablePropertiesQueriesASTs.h>
#include <Parsers/formatAST.h>
#include <Parsers/parseQuery.h>
#include <Parsers/queryToString.h>
#include <base/types.h>

#include <gtest/gtest.h>
#include <Poco/JSON/Parser.h>
#include "Common/Exception.h"
#include "Parsers/IParser.h"

using namespace DB;

TEST(ParserCreateRemoteFunctionQuery, UDFNoHeaderMethod)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/';";
ParserCreateFunctionQuery parser;
ASTPtr ast = parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0);
ASTCreateFunctionQuery * create = ast->as<ASTCreateFunctionQuery>();
EXPECT_EQ(create->getFunctionName(), "ip_lookup");
EXPECT_EQ(create->lang, "Remote");
EXPECT_NE(create->function_core, nullptr);
EXPECT_NE(create->arguments, nullptr);

/// Check arguments
String args = queryToString(*create->arguments.get(), true);
EXPECT_EQ(args, "(ip string)");

/// Check return type
String ret = queryToString(*create->return_type.get(), true);
EXPECT_EQ(ret, "string");

auto remote_func_settings = create->remote_func_settings;
EXPECT_EQ(remote_func_settings->get("URL").toString(), "https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/");
EXPECT_FALSE(remote_func_settings->has("AUTH_METHOD"));
EXPECT_FALSE(remote_func_settings->has("AUTH_HEADER"));
EXPECT_FALSE(remote_func_settings->has("AUTH_KEY"));
}

TEST(ParserCreateRemoteFunctionQuery, UDFHeaderMethodIsNone)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'"
"AUTH_METHOD 'none'";
ParserCreateFunctionQuery parser;
ASTPtr ast = parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0);
ASTCreateFunctionQuery * create = ast->as<ASTCreateFunctionQuery>();
EXPECT_EQ(create->getFunctionName(), "ip_lookup");
EXPECT_EQ(create->lang, "Remote");
EXPECT_NE(create->function_core, nullptr);
EXPECT_NE(create->arguments, nullptr);

/// Check arguments
String args = queryToString(*create->arguments.get(), true);
EXPECT_EQ(args, "(ip string)");

/// Check return type
String ret = queryToString(*create->return_type.get(), true);
EXPECT_EQ(ret, "string");

auto remote_func_settings = create->remote_func_settings;
EXPECT_EQ(remote_func_settings->get("URL").toString(), "https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/");
EXPECT_EQ(remote_func_settings->get("AUTH_METHOD").toString(), "none");
EXPECT_FALSE(remote_func_settings->has("AUTH_HEADER"));
EXPECT_FALSE(remote_func_settings->has("AUTH_KEY"));
}

TEST(ParserCreateRemoteFunctionQuery, UDFHeaderMethodIsAuthHeader)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'"
"AUTH_METHOD 'auth_header'"
"AUTH_HEADER 'auth'"
"AUTH_KEY 'proton'";
ParserCreateFunctionQuery parser;
ASTPtr ast = parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0);
ASTCreateFunctionQuery * create = ast->as<ASTCreateFunctionQuery>();
EXPECT_EQ(create->getFunctionName(), "ip_lookup");
EXPECT_EQ(std::string(magic_enum::enum_name(create->lang));, "Remote");
EXPECT_NE(create->function_core, nullptr);
EXPECT_NE(create->arguments, nullptr);

/// Check arguments
String args = queryToString(*create->arguments.get(), true);
EXPECT_EQ(args, "(ip string)");

/// Check return type
String ret = queryToString(*create->return_type.get(), true);
EXPECT_EQ(ret, "string");

auto remote_func_settings = create->remote_func_settings;
EXPECT_EQ(remote_func_settings->get("URL").toString(), "https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/");
EXPECT_EQ(remote_func_settings->get("AUTH_METHOD").toString(), "auth_header");
EXPECT_TRUE(remote_func_settings->has("AUTH_HEADER"));
EXPECT_TRUE(remote_func_settings->has("AUTH_KEY"));
EXPECT_EQ(remote_func_settings->get("AUTH_HEADER"), "auth");
EXPECT_EQ(remote_func_settings->get("AUTH_KEY"), "proton");
}


TEST(ParserCreateRemoteFunctionQuery, UDFHeaderMethodIsOther)
{
String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string "
"URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'"
"AUTH_METHOD 'token'";
ParserCreateFunctionQuery parser;
EXPECT_THROW(parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0), Exception);

}
Loading

0 comments on commit 750f022

Please sign in to comment.