Skip to content

Commit

Permalink
support udf like with 3 arguments (#212)
Browse files Browse the repository at this point in the history
* support udf like with 3 arguments

* address comments

* add some comments
  • Loading branch information
windtalker authored and zanmato1984 committed Sep 2, 2019
1 parent 39d1994 commit 8a0fb66
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 5 deletions.
2 changes: 1 addition & 1 deletion dbms/src/Flash/Coprocessor/DAGUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ std::unordered_map<tipb::ScalarFuncSig, String> scalar_func_map({
//{tipb::ScalarFuncSig::IsIPv6, "cast"},
//{tipb::ScalarFuncSig::UUID, "cast"},

//{tipb::ScalarFuncSig::LikeSig, "cast"},
{tipb::ScalarFuncSig::LikeSig, "like3Args"},
//{tipb::ScalarFuncSig::RegexpBinarySig, "cast"},
//{tipb::ScalarFuncSig::RegexpSig, "cast"},

Expand Down
6 changes: 6 additions & 0 deletions dbms/src/Functions/FunctionsStringSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,10 @@ struct NameLike
{
static constexpr auto name = "like";
};
struct NameLike3Args
{
static constexpr auto name = "like3Args";
};
struct NameNotLike
{
static constexpr auto name = "notLike";
Expand Down Expand Up @@ -1058,6 +1062,7 @@ using FunctionPositionCaseInsensitiveUTF8

using FunctionMatch = FunctionsStringSearch<MatchImpl<false>, NameMatch>;
using FunctionLike = FunctionsStringSearch<MatchImpl<true>, NameLike>;
using FunctionLike3Args = FunctionsStringSearch<MatchImpl<true>, NameLike3Args, 3>;
using FunctionNotLike = FunctionsStringSearch<MatchImpl<true, true>, NameNotLike>;
using FunctionExtract = FunctionsStringSearchToString<ExtractImpl, NameExtract>;
using FunctionReplaceOne = FunctionStringReplace<ReplaceStringImpl<true>, NameReplaceOne>;
Expand All @@ -1078,6 +1083,7 @@ void registerFunctionsStringSearch(FunctionFactory & factory)
factory.registerFunction<FunctionPositionCaseInsensitiveUTF8>();
factory.registerFunction<FunctionMatch>();
factory.registerFunction<FunctionLike>();
factory.registerFunction<FunctionLike3Args>();
factory.registerFunction<FunctionNotLike>();
factory.registerFunction<FunctionExtract>();
}
Expand Down
102 changes: 98 additions & 4 deletions dbms/src/Functions/FunctionsStringSearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ namespace DB
* Warning! At this point, the arguments needle, pattern, n, replacement must be constants.
*/

static const UInt8 CH_ESCAPE_CHAR = '\\';

template <typename Impl, typename Name>
template <typename Impl, typename Name, size_t num_args = 2>
class FunctionsStringSearch : public IFunction
{
public:
static constexpr auto name = Name::name;
static constexpr auto has_3_args = (num_args == 3);
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionsStringSearch>();
Expand All @@ -56,7 +58,7 @@ class FunctionsStringSearch : public IFunction

size_t getNumberOfArguments() const override
{
return 2;
return num_args;
}

DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
Expand All @@ -68,10 +70,60 @@ class FunctionsStringSearch : public IFunction
if (!arguments[1]->isString())
throw Exception(
"Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (has_3_args && !arguments[2]->isInteger())
throw Exception(
"Illegal type " + arguments[2]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

return std::make_shared<DataTypeNumber<typename Impl::ResultType>>();
}

// replace the escape_char in orig_string with '\\'
// this function does not check the validation of the orig_string
// for example, for string "abcd" and escape char 'd', it will
// return "abc\\"
String replaceEscapeChar(String & orig_string, UInt8 escape_char)
{
std::stringstream ss;
for (size_t i = 0; i < orig_string.size(); i++)
{
auto c = orig_string[i];
if (c == escape_char)
{
if (i+1 != orig_string.size() && orig_string[i+1] == escape_char)
{
// two successive escape char, which means it is trying to escape itself, just remove one
i++;
ss << escape_char;
}
else
{
// https://github.com/pingcap/tidb/blob/master/util/stringutil/string_util.go#L154
// if any char following escape char that is not [escape_char,'_','%'], it is invalid escape.
// mysql will treat escape character as the origin value even
// the escape sequence is invalid in Go or C.
// e.g., \m is invalid in Go, but in MySQL we will get "m" for select '\m'.
// Following case is correct just for escape \, not for others like +.
// TODO: Add more checks for other escapes.
if (i+1 != orig_string.size() && orig_string[i+1] == CH_ESCAPE_CHAR)
{
continue;
}
ss << CH_ESCAPE_CHAR;
}
}
else if (c == CH_ESCAPE_CHAR)
{
// need to escape this '\\'
ss << CH_ESCAPE_CHAR << CH_ESCAPE_CHAR;
}
else
{
ss << c;
}
}
return ss.str();
}

void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
{
using ResultType = typename Impl::ResultType;
Expand All @@ -82,10 +134,44 @@ class FunctionsStringSearch : public IFunction
const ColumnConst * col_haystack_const = typeid_cast<const ColumnConst *>(&*column_haystack);
const ColumnConst * col_needle_const = typeid_cast<const ColumnConst *>(&*column_needle);

UInt8 escape_char = CH_ESCAPE_CHAR;
if (has_3_args)
{
auto * col_escape_const = typeid_cast<const ColumnConst *>(&*block.getByPosition(arguments[2]).column);
bool valid_args = true;
if (col_needle_const == nullptr || col_escape_const == nullptr)
{
valid_args = false;
}
else
{
auto c = col_escape_const->getValue<Int32>();
if (c < 0 || c > 255)
{
// todo maybe use more strict constraint
valid_args = false;
}
else
{
escape_char = (UInt8) c;
}
}
if (!valid_args)
{
throw Exception("2nd and 3rd arguments of function " + getName() + " must "
"be constants, and the 3rd argument must between 0 and 255.");
}
}

if (col_haystack_const && col_needle_const)
{
ResultType res{};
Impl::constant_constant(col_haystack_const->getValue<String>(), col_needle_const->getValue<String>(), res);
String needle_string = col_needle_const->getValue<String>();
if (has_3_args && escape_char != CH_ESCAPE_CHAR)
{
needle_string = replaceEscapeChar(needle_string, escape_char);
}
Impl::constant_constant(col_haystack_const->getValue<String>(), needle_string, res);
block.getByPosition(result).column = block.getByPosition(result).type->createColumnConst(col_haystack_const->size(), toField(res));
return;
}
Expand All @@ -105,7 +191,15 @@ class FunctionsStringSearch : public IFunction
col_needle_vector->getOffsets(),
vec_res);
else if (col_haystack_vector && col_needle_const)
Impl::vector_constant(col_haystack_vector->getChars(), col_haystack_vector->getOffsets(), col_needle_const->getValue<String>(), vec_res);
{
String needle_string = col_needle_const->getValue<String>();
if (has_3_args && escape_char != CH_ESCAPE_CHAR)
{
needle_string = replaceEscapeChar(needle_string, escape_char);
}
Impl::vector_constant(col_haystack_vector->getChars(), col_haystack_vector->getOffsets(),
needle_string, vec_res);
}
else if (col_haystack_const && col_needle_vector)
Impl::constant_vector(col_haystack_const->getValue<String>(), col_needle_vector->getChars(), col_needle_vector->getOffsets(), vec_res);
else
Expand Down

0 comments on commit 8a0fb66

Please sign in to comment.