From 8a0fb6612d6eeaa302b5c991e22b510088032a83 Mon Sep 17 00:00:00 2001 From: xufei Date: Mon, 2 Sep 2019 16:45:05 +0800 Subject: [PATCH] support udf like with 3 arguments (#212) * support udf like with 3 arguments * address comments * add some comments --- dbms/src/Flash/Coprocessor/DAGUtils.cpp | 2 +- dbms/src/Functions/FunctionsStringSearch.cpp | 6 ++ dbms/src/Functions/FunctionsStringSearch.h | 102 ++++++++++++++++++- 3 files changed, 105 insertions(+), 5 deletions(-) diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 9359bb7de06..90fe7cb1055 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -624,7 +624,7 @@ std::unordered_map scalar_func_map({ //{tipb::ScalarFuncSig::IsIPv6, "cast"}, //{tipb::ScalarFuncSig::UUID, "cast"}, - //{tipb::ScalarFuncSig::LikeSig, "cast"}, + {tipb::ScalarFuncSig::LikeSig, "like3Args"}, //{tipb::ScalarFuncSig::RegexpBinarySig, "cast"}, //{tipb::ScalarFuncSig::RegexpSig, "cast"}, diff --git a/dbms/src/Functions/FunctionsStringSearch.cpp b/dbms/src/Functions/FunctionsStringSearch.cpp index 2b356923152..37f905f6606 100644 --- a/dbms/src/Functions/FunctionsStringSearch.cpp +++ b/dbms/src/Functions/FunctionsStringSearch.cpp @@ -1025,6 +1025,10 @@ struct NameLike { static constexpr auto name = "like"; }; +struct NameLike3Args +{ + static constexpr auto name = "like3Args"; +}; struct NameNotLike { static constexpr auto name = "notLike"; @@ -1058,6 +1062,7 @@ using FunctionPositionCaseInsensitiveUTF8 using FunctionMatch = FunctionsStringSearch, NameMatch>; using FunctionLike = FunctionsStringSearch, NameLike>; +using FunctionLike3Args = FunctionsStringSearch, NameLike3Args, 3>; using FunctionNotLike = FunctionsStringSearch, NameNotLike>; using FunctionExtract = FunctionsStringSearchToString; using FunctionReplaceOne = FunctionStringReplace, NameReplaceOne>; @@ -1078,6 +1083,7 @@ void registerFunctionsStringSearch(FunctionFactory & factory) factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); + factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); } diff --git a/dbms/src/Functions/FunctionsStringSearch.h b/dbms/src/Functions/FunctionsStringSearch.h index 9de117464a2..c132d2a1bd3 100644 --- a/dbms/src/Functions/FunctionsStringSearch.h +++ b/dbms/src/Functions/FunctionsStringSearch.h @@ -38,12 +38,14 @@ namespace DB * Warning! At this point, the arguments needle, pattern, n, replacement must be constants. */ +static const UInt8 CH_ESCAPE_CHAR = '\\'; -template +template class FunctionsStringSearch : public IFunction { public: static constexpr auto name = Name::name; + static constexpr auto has_3_args = (num_args == 3); static FunctionPtr create(const Context &) { return std::make_shared(); @@ -56,7 +58,7 @@ class FunctionsStringSearch : public IFunction size_t getNumberOfArguments() const override { - return 2; + return num_args; } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override @@ -68,10 +70,60 @@ class FunctionsStringSearch : public IFunction if (!arguments[1]->isString()) throw Exception( "Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if (has_3_args && !arguments[2]->isInteger()) + throw Exception( + "Illegal type " + arguments[2]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); return std::make_shared>(); } + // replace the escape_char in orig_string with '\\' + // this function does not check the validation of the orig_string + // for example, for string "abcd" and escape char 'd', it will + // return "abc\\" + String replaceEscapeChar(String & orig_string, UInt8 escape_char) + { + std::stringstream ss; + for (size_t i = 0; i < orig_string.size(); i++) + { + auto c = orig_string[i]; + if (c == escape_char) + { + if (i+1 != orig_string.size() && orig_string[i+1] == escape_char) + { + // two successive escape char, which means it is trying to escape itself, just remove one + i++; + ss << escape_char; + } + else + { + // https://github.com/pingcap/tidb/blob/master/util/stringutil/string_util.go#L154 + // if any char following escape char that is not [escape_char,'_','%'], it is invalid escape. + // mysql will treat escape character as the origin value even + // the escape sequence is invalid in Go or C. + // e.g., \m is invalid in Go, but in MySQL we will get "m" for select '\m'. + // Following case is correct just for escape \, not for others like +. + // TODO: Add more checks for other escapes. + if (i+1 != orig_string.size() && orig_string[i+1] == CH_ESCAPE_CHAR) + { + continue; + } + ss << CH_ESCAPE_CHAR; + } + } + else if (c == CH_ESCAPE_CHAR) + { + // need to escape this '\\' + ss << CH_ESCAPE_CHAR << CH_ESCAPE_CHAR; + } + else + { + ss << c; + } + } + return ss.str(); + } + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override { using ResultType = typename Impl::ResultType; @@ -82,10 +134,44 @@ class FunctionsStringSearch : public IFunction const ColumnConst * col_haystack_const = typeid_cast(&*column_haystack); const ColumnConst * col_needle_const = typeid_cast(&*column_needle); + UInt8 escape_char = CH_ESCAPE_CHAR; + if (has_3_args) + { + auto * col_escape_const = typeid_cast(&*block.getByPosition(arguments[2]).column); + bool valid_args = true; + if (col_needle_const == nullptr || col_escape_const == nullptr) + { + valid_args = false; + } + else + { + auto c = col_escape_const->getValue(); + if (c < 0 || c > 255) + { + // todo maybe use more strict constraint + valid_args = false; + } + else + { + escape_char = (UInt8) c; + } + } + if (!valid_args) + { + throw Exception("2nd and 3rd arguments of function " + getName() + " must " + "be constants, and the 3rd argument must between 0 and 255."); + } + } + if (col_haystack_const && col_needle_const) { ResultType res{}; - Impl::constant_constant(col_haystack_const->getValue(), col_needle_const->getValue(), res); + String needle_string = col_needle_const->getValue(); + if (has_3_args && escape_char != CH_ESCAPE_CHAR) + { + needle_string = replaceEscapeChar(needle_string, escape_char); + } + Impl::constant_constant(col_haystack_const->getValue(), needle_string, res); block.getByPosition(result).column = block.getByPosition(result).type->createColumnConst(col_haystack_const->size(), toField(res)); return; } @@ -105,7 +191,15 @@ class FunctionsStringSearch : public IFunction col_needle_vector->getOffsets(), vec_res); else if (col_haystack_vector && col_needle_const) - Impl::vector_constant(col_haystack_vector->getChars(), col_haystack_vector->getOffsets(), col_needle_const->getValue(), vec_res); + { + String needle_string = col_needle_const->getValue(); + if (has_3_args && escape_char != CH_ESCAPE_CHAR) + { + needle_string = replaceEscapeChar(needle_string, escape_char); + } + Impl::vector_constant(col_haystack_vector->getChars(), col_haystack_vector->getOffsets(), + needle_string, vec_res); + } else if (col_haystack_const && col_needle_vector) Impl::constant_vector(col_haystack_const->getValue(), col_needle_vector->getChars(), col_needle_vector->getOffsets(), vec_res); else