diff --git a/src/Parsers/ASTCreateFunctionQuery.cpp b/src/Parsers/ASTCreateFunctionQuery.cpp index 512cc1bd2e4..3f59dfd60e7 100644 --- a/src/Parsers/ASTCreateFunctionQuery.cpp +++ b/src/Parsers/ASTCreateFunctionQuery.cpp @@ -1,13 +1,13 @@ -#include #include #include #include #include +#include /// proton: starts -#include -#include #include +#include +#include #include /// proton: ends @@ -26,7 +26,7 @@ ASTPtr ASTCreateFunctionQuery::clone() const res->function_core = function_core->clone(); res->children.push_back(res->function_core); - res->payload = payload; + res->remote_func_settings = remote_func_settings; return res; } @@ -38,8 +38,11 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I settings.ostr << "OR REPLACE "; /// proton: starts + bool is_remote = isRemote(); if (is_aggregation) settings.ostr << "AGGREGATE FUNCTION "; + else if (is_remote) + settings.ostr << "REMOTE FUNCTION "; else settings.ostr << "FUNCTION "; /// proton: ends @@ -53,7 +56,6 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I /// proton: starts bool is_javascript_func = isJavaScript(); - bool is_remote = isRemote(); if (is_javascript_func || is_remote) { /// arguments @@ -70,11 +72,17 @@ void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, I /// proton: starts if (is_remote) { - settings.ostr << (settings.hilite ? hilite_keyword : "") << fmt::format("\nTYPE Remote \n") << (settings.hilite ? hilite_none : ""); - settings.ostr << fmt::format("URL '{}'\n", payload->get("AUTH_METHOD").toString()); - settings.ostr << fmt::format("AUTH_METHOD '{}'\n", payload->has("AUTH_METHOD") ? payload->get("AUTH_METHOD").toString() : "none"); - settings.ostr << fmt::format("AUTH_HEADER '{}'\n", payload->has("AUTH_HEADER") ? payload->get("AUTH_HEADER").toString() : "none"); - settings.ostr << fmt::format("AUTH_KEY '{}'\n", payload->has("AUTH_KEY") ? payload->get("AUTH_KEY").toString() : "none"); + settings.ostr << fmt::format("\nURL '{}'\n", remote_func_settings->get("URL").toString()); + auto auth_method = remote_func_settings->has("AUTH_METHOD") ? remote_func_settings->get("AUTH_METHOD").toString() : "none"; + settings.ostr << fmt::format("AUTH_METHOD '{}'\n", auth_method); + if (auth_method != "none") + { + settings.ostr << fmt::format( + "AUTH_HEADER '{}'\n", + remote_func_settings->has("AUTH_HEADER") ? remote_func_settings->get("AUTH_HEADER").toString() : "none"); + settings.ostr << fmt::format( + "AUTH_KEY '{}'\n", remote_func_settings->has("AUTH_KEY") ? remote_func_settings->get("AUTH_KEY").toString() : "none"); + } return; } /// proton: ends @@ -144,18 +152,22 @@ Poco::JSON::Object::Ptr ASTCreateFunctionQuery::toJSON() const formatAST(*return_type, return_buf, false); inner_func->set("return_type", return_buf.str()); - /// remote functio + /// remote function if (is_remote) { - inner_func->set("url", payload->get("URL").toString()); + inner_func->set("url", remote_func_settings->get("URL").toString()); // auth - if (payload->has("AUTH_METHOD")) + if (remote_func_settings->has("AUTH_METHOD")) { - inner_func->set("auth_method", payload->get("AUTH_METHOD").toString()); - Poco::JSON::Object::Ptr auth_context = new Poco::JSON::Object(); - auth_context->set("key_name", payload->get("AUTH_HEADER").toString()); - auth_context->set("key_value", payload->get("AUTH_KEY").toString()); - inner_func->set("auth_context", auth_context); + auto auth_method(remote_func_settings->get("AUTH_METHOD").toString()); + inner_func->set("auth_method", auth_method); + if (auth_method == "auth_header") + { + Poco::JSON::Object::Ptr auth_context = new Poco::JSON::Object(); + auth_context->set("key_name", remote_func_settings->get("AUTH_HEADER").toString()); + auth_context->set("key_value", remote_func_settings->get("AUTH_KEY").toString()); + inner_func->set("auth_context", auth_context); + } } func->set("function", inner_func); /// Remote function don't have source, return early. diff --git a/src/Parsers/ASTCreateFunctionQuery.h b/src/Parsers/ASTCreateFunctionQuery.h index 7ff66a65b6d..fdcca91f37f 100644 --- a/src/Parsers/ASTCreateFunctionQuery.h +++ b/src/Parsers/ASTCreateFunctionQuery.h @@ -15,7 +15,7 @@ class ASTCreateFunctionQuery : public IAST, public ASTQueryWithOnCluster public: ASTPtr function_name; ASTPtr function_core; - Poco::JSON::Object::Ptr payload; + Poco::JSON::Object::Ptr remote_func_settings; bool or_replace = false; bool if_not_exists = false; @@ -44,7 +44,7 @@ class ASTCreateFunctionQuery : public IAST, public ASTQueryWithOnCluster bool isJavaScript() const noexcept { return lang == "JavaScript"; } /// If it is a JavaScript UDF - bool isRemote() const noexcept { return lang == "remote"; } + bool isRemote() const noexcept { return lang == "Remote"; } /// proton: ends }; diff --git a/src/Parsers/ParserCreateFunctionQuery.cpp b/src/Parsers/ParserCreateFunctionQuery.cpp index cd8448e966e..560ee11b87f 100644 --- a/src/Parsers/ParserCreateFunctionQuery.cpp +++ b/src/Parsers/ParserCreateFunctionQuery.cpp @@ -8,9 +8,9 @@ #include /// proton: starts +#include #include #include -#include #include /// proton: ends @@ -21,9 +21,10 @@ namespace DB namespace ErrorCodes { - extern const int AGGREGATE_FUNCTION_NOT_APPLICABLE; +extern const int AGGREGATE_FUNCTION_NOT_APPLICABLE; +extern const int UNKNOWN_FUNCTION; } -bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected & expected, [[ maybe_unused ]] bool hint) +bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected & expected, [[maybe_unused]] bool hint) { ParserKeyword s_create("CREATE"); ParserKeyword s_function("FUNCTION"); @@ -32,7 +33,7 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp ParserKeyword s_aggr_function("AGGREGATE FUNCTION"); ParserKeyword s_returns("RETURNS"); ParserKeyword s_javascript_type("LANGUAGE JAVASCRIPT"); - ParserKeyword s_type_remote("TYPE Remote"); + ParserKeyword s_remote("REMOTE FUNCTION"); ParserKeyword s_url("URL"); ParserKeyword s_auth_method("AUTH_METHOD"); ParserKeyword s_auth_header("AUTH_HEADER"); @@ -79,10 +80,12 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp /// proton: starts if (!s_function.ignore(pos, expected)) { - if(!s_aggr_function.ignore(pos, expected)) + if (s_aggr_function.ignore(pos, expected)) + is_aggregation = true; + else if (s_remote.ignore(pos, expected)) + is_remote = true; + else return false; - - is_aggregation = true; } /// proton: ends @@ -106,14 +109,11 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp /// proton: starts if (is_new_syntax && s_returns.ignore(pos, expected)) { - if(!return_p.parse(pos, return_type, expected)) + if (!return_p.parse(pos, return_type, expected)) return false; if (s_javascript_type.ignore(pos, expected)) is_javascript_func = true; - else if (s_type_remote.ignore(pos,expected)){ - is_remote = true; - } if (!is_remote && !s_as.ignore(pos, expected)) return false; @@ -131,32 +131,44 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp if (!lambda_p.parse(pos, function_core, expected)) return false; } - Poco::JSON::Object::Ptr payload = new Poco::JSON::Object(); - if (is_remote){ - if (is_aggregation){ - throw Exception("Remote udf can not be an aggregate function",ErrorCodes::AGGREGATE_FUNCTION_NOT_APPLICABLE); + Poco::JSON::Object::Ptr remote_func_settings = new Poco::JSON::Object(); + if (is_remote) + { + if (is_aggregation) + { + throw Exception("Remote udf can not be an aggregate function", ErrorCodes::AGGREGATE_FUNCTION_NOT_APPLICABLE); } - if (!s_url.ignore(pos,expected)) + if (!s_url.ignore(pos, expected)) return false; if (!value.parse(pos, url, expected)) return false; - function_core = std::make_shared(Field()); - payload->set("URL", url->as()->value.safeGet()); - if (s_auth_method.ignore(pos,expected)){ + remote_func_settings->set("URL", url->as()->value.safeGet()); + if (s_auth_method.ignore(pos, expected)) + { if (!value.parse(pos, auth_method, expected)) return false; - if (!s_auth_header.ignore(pos, expected)) - return false; - if (!value.parse(pos, auth_header, expected)) - return false; - if (!s_auth_key.ignore(pos, expected)) - return false; - if (!value.parse(pos, auth_key, expected)) - return false; - payload->set("AUTH_METHOD", auth_method->as()->value.safeGet()); - payload->set("AUTH_HEADER", auth_header->as()->value.safeGet()); - payload->set("AUTH_KEY", auth_key->as()->value.safeGet()); + auto method_str = auth_method->as()->value.safeGet(); + if (method_str == "auth_header") + { + if (!s_auth_header.ignore(pos, expected)) + return false; + if (!value.parse(pos, auth_header, expected)) + return false; + if (!s_auth_key.ignore(pos, expected)) + return false; + if (!value.parse(pos, auth_key, expected)) + return false; + remote_func_settings->set("AUTH_HEADER", auth_header->as()->value.safeGet()); + remote_func_settings->set("AUTH_KEY", auth_key->as()->value.safeGet()); + } + else if (method_str != "none") + { + throw Exception("Auth_method must be 'none' or 'auth_header'", ErrorCodes::UNKNOWN_FUNCTION); + } + remote_func_settings->set("AUTH_METHOD", method_str); } + + function_core = std::make_shared(Field()); } /// proton: ends @@ -175,10 +187,10 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp /// proton: starts create_function_query->is_aggregation = is_aggregation; - create_function_query->lang = is_javascript_func ? "JavaScript" : is_remote ? "remote" : "SQL"; + create_function_query->lang = is_javascript_func ? "JavaScript" : is_remote ? "Remote" : "SQL"; create_function_query->arguments = std::move(arguments); create_function_query->return_type = std::move(return_type); - create_function_query->payload = payload; + create_function_query->remote_func_settings = remote_func_settings; /// proton: ends return true; diff --git a/src/Parsers/tests/gtest_create_remote_func_parser.cpp b/src/Parsers/tests/gtest_create_remote_func_parser.cpp new file mode 100644 index 00000000000..6aaa831be84 --- /dev/null +++ b/src/Parsers/tests/gtest_create_remote_func_parser.cpp @@ -0,0 +1,119 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "Common/Exception.h" +#include "Parsers/IParser.h" + +using namespace DB; + +TEST(ParserCreateRemoteFunctionQuery, UDFNoHeaderMethod) +{ + String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string " + "URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/';"; + ParserCreateFunctionQuery parser; + ASTPtr ast = parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0); + ASTCreateFunctionQuery * create = ast->as(); + EXPECT_EQ(create->getFunctionName(), "ip_lookup"); + EXPECT_EQ(create->lang, "Remote"); + EXPECT_NE(create->function_core, nullptr); + EXPECT_NE(create->arguments, nullptr); + + /// Check arguments + String args = queryToString(*create->arguments.get(), true); + EXPECT_EQ(args, "(ip string)"); + + /// Check return type + String ret = queryToString(*create->return_type.get(), true); + EXPECT_EQ(ret, "string"); + + auto remote_func_settings = create->remote_func_settings; + EXPECT_EQ(remote_func_settings->get("URL").toString(), "https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/"); + EXPECT_FALSE(remote_func_settings->has("AUTH_METHOD")); + EXPECT_FALSE(remote_func_settings->has("AUTH_HEADER")); + EXPECT_FALSE(remote_func_settings->has("AUTH_KEY")); +} + +TEST(ParserCreateRemoteFunctionQuery, UDFHeaderMethodIsNone) +{ + String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string " + "URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'" + "AUTH_METHOD 'none'"; + ParserCreateFunctionQuery parser; + ASTPtr ast = parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0); + ASTCreateFunctionQuery * create = ast->as(); + EXPECT_EQ(create->getFunctionName(), "ip_lookup"); + EXPECT_EQ(create->lang, "Remote"); + EXPECT_NE(create->function_core, nullptr); + EXPECT_NE(create->arguments, nullptr); + + /// Check arguments + String args = queryToString(*create->arguments.get(), true); + EXPECT_EQ(args, "(ip string)"); + + /// Check return type + String ret = queryToString(*create->return_type.get(), true); + EXPECT_EQ(ret, "string"); + + auto remote_func_settings = create->remote_func_settings; + EXPECT_EQ(remote_func_settings->get("URL").toString(), "https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/"); + EXPECT_EQ(remote_func_settings->get("AUTH_METHOD").toString(), "none"); + EXPECT_FALSE(remote_func_settings->has("AUTH_HEADER")); + EXPECT_FALSE(remote_func_settings->has("AUTH_KEY")); +} + +TEST(ParserCreateRemoteFunctionQuery, UDFHeaderMethodIsAuthHeader) +{ + String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string " + "URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'" + "AUTH_METHOD 'auth_header'" + "AUTH_HEADER 'auth'" + "AUTH_KEY 'proton'"; + ParserCreateFunctionQuery parser; + ASTPtr ast = parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0); + ASTCreateFunctionQuery * create = ast->as(); + EXPECT_EQ(create->getFunctionName(), "ip_lookup"); + EXPECT_EQ(std::string(magic_enum::enum_name(create->lang));, "Remote"); + EXPECT_NE(create->function_core, nullptr); + EXPECT_NE(create->arguments, nullptr); + + /// Check arguments + String args = queryToString(*create->arguments.get(), true); + EXPECT_EQ(args, "(ip string)"); + + /// Check return type + String ret = queryToString(*create->return_type.get(), true); + EXPECT_EQ(ret, "string"); + + auto remote_func_settings = create->remote_func_settings; + EXPECT_EQ(remote_func_settings->get("URL").toString(), "https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/"); + EXPECT_EQ(remote_func_settings->get("AUTH_METHOD").toString(), "auth_header"); + EXPECT_TRUE(remote_func_settings->has("AUTH_HEADER")); + EXPECT_TRUE(remote_func_settings->has("AUTH_KEY")); + EXPECT_EQ(remote_func_settings->get("AUTH_HEADER"), "auth"); + EXPECT_EQ(remote_func_settings->get("AUTH_KEY"), "proton"); +} + + +TEST(ParserCreateRemoteFunctionQuery, UDFHeaderMethodIsOther) +{ + String input = "CREATE REMOTE FUNCTION ip_lookup(ip string) RETURNS string " + "URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'" + "AUTH_METHOD 'token'"; + ParserCreateFunctionQuery parser; + EXPECT_THROW(parseQuery(parser, input.data(), input.data() + input.size(), "", 0, 0), Exception); + +} \ No newline at end of file diff --git a/tests/stream/test_stream_smoke/0022_udf3_create_remote_func.yaml b/tests/stream/test_stream_smoke/0022_udf3_create_remote_func.yaml index cad0677db20..2f7a5c62fc5 100644 --- a/tests/stream/test_stream_smoke/0022_udf3_create_remote_func.yaml +++ b/tests/stream/test_stream_smoke/0022_udf3_create_remote_func.yaml @@ -34,9 +34,9 @@ tests: query_id: udf-29-0 wait: 1 query: | - CREATE FUNCTION ip_lookup(ip string) RETURNS string - TYPE Remote - URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'; + CREATE Remote FUNCTION ip_lookup(ip string) RETURNS string + URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/' + AUTH_METHOD 'none'; - client: python query_id: udf-29-1 @@ -45,7 +45,7 @@ tests: query_type: table wait: 5 query: | - select ip_lookup('1.1.1.1'); + select ip_lookup('127.0.0.1'); - client: python query_id: udf-29-2 @@ -67,16 +67,15 @@ tests: expected_results: - query_id: udf-29-1 expected_results: - - [ '{"ip":"1.1.1.1","hostname":"one.one.one.one","anycast":true,"city":"Englewood","region":"Colorado","country":"US","loc":"39.6123,-104.8799","org":"AS13335 Cloudflare, Inc.","postal":"80111","timezone":"America/Denver"}' ] + - [ '{"ip":"127.0.0.1","bogon":true}' ] - query_id: udf-29-2 expected_results: - [ '{"status":404,"error":{"title":"Wrong ip","message":"Please provide a valid IP address"}}'] - - id: 1 tags: - udf3_create_remote_func - name: create remote uda - description: create remote uda failed + name: create remote udf with auth + description: SQL - remote UDF steps: - statements: - client: python @@ -88,11 +87,43 @@ tests: query_id: udf-30-0 wait: 1 query: | - CREATE AGGREGATE FUNCTION ip_lookup(ip string) RETURNS string - TYPE Remote - URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/'; + CREATE Remote FUNCTION ip_lookup(ip string) RETURNS string + URL 'https://hn6wip76uexaeusz5s7bh3e4u40lrrrz.lambda-url.us-west-2.on.aws/' + AUTH_METHOD 'auth_header' + AUTH_HEADER 'auth' + AUTH_KEY 'proton'; + - client: python + query_id: udf-30-1 + query_end_timer: 7 + depends_on_done: udf-30-0 + query_type: table + wait: 5 + query: | + select ip_lookup('127.0.0.1'); + + - client: python + query_id: udf-30-2 + query_end_timer: 7 + depends_on_done: udf-30-0 + query_type: table + wait: 5 + query: | + select ip_lookup('1'); + + - client: python + query_type: table + query_id: udf-30-3 + depends_on_done: udf-30-0 + wait: 1 + query: | + DROP FUNCTION ip_lookup; + expected_results: - - query_id: udf-30-0 - expected_results: "error_code:154" - + - query_id: udf-30-1 + expected_results: + - [ '{"ip":"127.0.0.1","bogon":true}' ] + - query_id: udf-30-2 + expected_results: + - [ '{"status":404,"error":{"title":"Wrong ip","message":"Please provide a valid IP address"}}'] +