Skip to content

Commit 8892fd1

Browse files
FLASH-783 NULL behavior fix for in expr (#386)
* fix expr in error * Revert "fix expr in error" This reverts commit 799b89d. * add tidbIn * 1. address comments, 2. fix build error * address comments Co-authored-by: ruoxi <[email protected]>
1 parent 12faa8f commit 8892fd1

File tree

10 files changed

+240
-53
lines changed

10 files changed

+240
-53
lines changed

Diff for: dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <Flash/Coprocessor/DAGCodec.h>
1010
#include <Flash/Coprocessor/DAGUtils.h>
1111
#include <Functions/FunctionFactory.h>
12+
#include <Functions/FunctionHelpers.h>
1213
#include <Interpreters/Context.h>
1314
#include <Interpreters/Set.h>
1415
#include <Interpreters/convertFieldToType.h>
@@ -416,7 +417,7 @@ void DAGExpressionAnalyzer::makeExplicitSetForIndex(const tipb::Expr & expr, con
416417
}
417418
const String & func_name = getFunctionName(expr);
418419
// only support col_name in (value_list)
419-
if (isInOrGlobalInOperator(func_name) && expr.children(0).tp() == tipb::ExprType::ColumnRef && !prepared_sets.count(&expr))
420+
if (functionIsInOrGlobalInOperator(func_name) && expr.children(0).tp() == tipb::ExprType::ColumnRef && !prepared_sets.count(&expr))
420421
{
421422
NamesAndTypesList column_list;
422423
for (const auto & col : getCurrentInputColumns())
@@ -488,7 +489,7 @@ String DAGExpressionAnalyzer::getActionsForInOperator(const tipb::Expr & expr, E
488489
// key not in (const1, const2, non_const1, non_const2) => and(key not in (const1, const2), key not eq non_const1, key not eq non_const2)
489490
argument_names.clear();
490491
argument_names.push_back(expr_name);
491-
bool is_not_in = func_name == "notIn" || func_name == "globalNotIn";
492+
bool is_not_in = func_name == "notIn" || func_name == "globalNotIn" || func_name == "tidbNotIn";
492493
for (const tipb::Expr * non_constant_expr : set->remaining_exprs)
493494
{
494495
Names eq_arg_names;
@@ -532,7 +533,7 @@ String DAGExpressionAnalyzer::getActions(const tipb::Expr & expr, ExpressionActi
532533
}
533534
const String & func_name = getFunctionName(expr);
534535

535-
if (isInOrGlobalInOperator(func_name))
536+
if (functionIsInOrGlobalInOperator(func_name))
536537
{
537538
return getActionsForInOperator(expr, actions);
538539
}

Diff for: dbms/src/Flash/Coprocessor/DAGUtils.cpp

+5-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <Core/Types.h>
44
#include <Flash/Coprocessor/DAGCodec.h>
5+
#include <Functions/FunctionHelpers.h>
56
#include <Interpreters/Context.h>
67
#include <Storages/Transaction/Datum.h>
78
#include <Storages/Transaction/TiDB.h>
@@ -119,7 +120,7 @@ String exprToString(const tipb::Expr & expr, const std::vector<NameAndTypePair>
119120
throw Exception(tipb::ExprType_Name(expr.tp()) + " not supported", ErrorCodes::UNSUPPORTED_METHOD);
120121
}
121122
// build function expr
122-
if (isInOrGlobalInOperator(func_name))
123+
if (functionIsInOrGlobalInOperator(func_name))
123124
{
124125
// for in, we cannot represent the function expr using func_name(param1, param2, ...)
125126
ss << exprToString(expr.children(0), input_col) << " " << func_name << " (";
@@ -263,8 +264,6 @@ String getColumnNameForColumnExpr(const tipb::Expr & expr, const std::vector<Nam
263264
return input_col[column_index].name;
264265
}
265266

266-
bool isInOrGlobalInOperator(const String & name) { return name == "in" || name == "notIn" || name == "globalIn" || name == "globalNotIn"; }
267-
268267
// for some historical or unknown reasons, TiDB might set an invalid
269268
// field type. This function checks if the expr has a valid field type
270269
// so far the known invalid field types are:
@@ -573,9 +572,9 @@ std::unordered_map<tipb::ScalarFuncSig, String> scalar_func_map({
573572
//{tipb::ScalarFuncSig::ValuesString, "cast"},
574573
//{tipb::ScalarFuncSig::ValuesTime, "cast"},
575574

576-
{tipb::ScalarFuncSig::InInt, "in"}, {tipb::ScalarFuncSig::InReal, "in"}, {tipb::ScalarFuncSig::InString, "in"},
577-
{tipb::ScalarFuncSig::InDecimal, "in"}, {tipb::ScalarFuncSig::InTime, "in"}, {tipb::ScalarFuncSig::InDuration, "in"},
578-
{tipb::ScalarFuncSig::InJson, "in"},
575+
{tipb::ScalarFuncSig::InInt, "tidbIn"}, {tipb::ScalarFuncSig::InReal, "tidbIn"}, {tipb::ScalarFuncSig::InString, "tidbIn"},
576+
{tipb::ScalarFuncSig::InDecimal, "tidbIn"}, {tipb::ScalarFuncSig::InTime, "tidbIn"}, {tipb::ScalarFuncSig::InDuration, "tidbIn"},
577+
{tipb::ScalarFuncSig::InJson, "tidbIn"},
579578

580579
{tipb::ScalarFuncSig::IfNullInt, "ifNull"}, {tipb::ScalarFuncSig::IfNullReal, "ifNull"}, {tipb::ScalarFuncSig::IfNullString, "ifNull"},
581580
{tipb::ScalarFuncSig::IfNullDecimal, "ifNull"}, {tipb::ScalarFuncSig::IfNullTime, "ifNull"},

Diff for: dbms/src/Flash/Coprocessor/DAGUtils.h

-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ bool isColumnExpr(const tipb::Expr & expr);
2525
String getColumnNameForColumnExpr(const tipb::Expr & expr, const std::vector<NameAndTypePair> & input_col);
2626
const String & getTypeName(const tipb::Expr & expr);
2727
String exprToString(const tipb::Expr & expr, const std::vector<NameAndTypePair> & input_col);
28-
bool isInOrGlobalInOperator(const String & name);
2928
bool exprHasValidFieldType(const tipb::Expr & expr);
3029
void constructStringLiteralTiExpr(tipb::Expr & expr, const String & value);
3130
void constructInt64LiteralTiExpr(tipb::Expr & expr, Int64 value);

Diff for: dbms/src/Functions/FunctionHelpers.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,14 @@ Block createBlockWithNestedColumns(const Block & block, const ColumnNumbers & ar
9696
return createBlockWithNestedColumnsImpl(block, args_set);
9797
}
9898

99+
bool functionIsInOperator(const String & name)
100+
{
101+
return name == "in" || name == "notIn" || name == "tidbIn" || name == "tidbNotIn";
102+
}
103+
104+
bool functionIsInOrGlobalInOperator(const String & name)
105+
{
106+
return name == "in" || name == "notIn" || name == "globalIn" || name == "globalNotIn" || name == "tidbIn" || name == "tidbNotIn";
107+
}
108+
99109
}

Diff for: dbms/src/Functions/FunctionHelpers.h

+3
Original file line numberDiff line numberDiff line change
@@ -99,5 +99,8 @@ Block createBlockWithNestedColumns(const Block & block, const ColumnNumbers & ar
9999
/// Similar function as above. Additionally transform the result type if needed.
100100
Block createBlockWithNestedColumns(const Block & block, const ColumnNumbers & args, size_t result);
101101

102+
bool functionIsInOperator(const String & name);
103+
104+
bool functionIsInOrGlobalInOperator(const String & name);
102105

103106
}

Diff for: dbms/src/Functions/FunctionsMiscellaneous.cpp

+104-15
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
#include <DataTypes/DataTypeArray.h>
2020
#include <DataTypes/DataTypeDate.h>
2121
#include <DataTypes/DataTypeDateTime.h>
22+
#include <DataTypes/DataTypeSet.h>
2223
#include <DataTypes/DataTypeString.h>
2324
#include <DataTypes/DataTypeTuple.h>
25+
#include <DataTypes/DataTypeNullable.h>
2426
#include <DataTypes/DataTypesNumber.h>
2527
#include <DataTypes/DataTypeEnum.h>
2628
#include <DataTypes/NumberTraits.h>
@@ -681,34 +683,44 @@ class FunctionMaterialize : public IFunction
681683
}
682684
};
683685

684-
template <bool negative, bool global>
686+
template <bool negative, bool global, bool ignore_null>
685687
struct FunctionInName;
686688
template <>
687-
struct FunctionInName<false, false>
689+
struct FunctionInName<false, false, true>
688690
{
689691
static constexpr auto name = "in";
690692
};
691693
template <>
692-
struct FunctionInName<false, true>
694+
struct FunctionInName<false, false, false>
695+
{
696+
static constexpr auto name = "tidbIn";
697+
};
698+
template <>
699+
struct FunctionInName<false, true, true>
693700
{
694701
static constexpr auto name = "globalIn";
695702
};
696703
template <>
697-
struct FunctionInName<true, false>
704+
struct FunctionInName<true, false, true>
698705
{
699706
static constexpr auto name = "notIn";
700707
};
701708
template <>
702-
struct FunctionInName<true, true>
709+
struct FunctionInName<true, false, false>
710+
{
711+
static constexpr auto name = "tidbNotIn";
712+
};
713+
template <>
714+
struct FunctionInName<true, true, true>
703715
{
704716
static constexpr auto name = "globalNotIn";
705717
};
706718

707-
template <bool negative, bool global>
719+
template <bool negative, bool global, bool ignore_null>
708720
class FunctionIn : public IFunction
709721
{
710722
public:
711-
static constexpr auto name = FunctionInName<negative, global>::name;
723+
static constexpr auto name = FunctionInName<negative, global, ignore_null>::name;
712724
static FunctionPtr create(const Context &)
713725
{
714726
return std::make_shared<FunctionIn>();
@@ -724,9 +736,27 @@ class FunctionIn : public IFunction
724736
return 2;
725737
}
726738

727-
DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override
739+
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
728740
{
729-
return std::make_shared<DataTypeUInt8>();
741+
if constexpr (ignore_null)
742+
return std::make_shared<DataTypeUInt8>();
743+
744+
auto type = removeNullable(arguments[0].type);
745+
if (typeid_cast<const DataTypeTuple *>(type.get()))
746+
throw Exception("Illegal type (" + arguments[0].type->getName() + ") of 1 argument of function " + getName(),
747+
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
748+
if (!typeid_cast<const DataTypeSet *>(arguments[1].type.get()))
749+
throw Exception("Illegal type (" + arguments[1].type->getName() + ") of 2 argument of function " + getName(),
750+
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
751+
bool return_nullable = arguments[0].type->isNullable();
752+
ColumnPtr column_set_ptr = arguments[1].column;
753+
auto * column_set = typeid_cast<const ColumnSet *>(&*column_set_ptr);
754+
return_nullable |= column_set->getData()->containsNullValue();
755+
756+
if (return_nullable)
757+
return makeNullable(std::make_shared<DataTypeUInt8>());
758+
else
759+
return std::make_shared<DataTypeUInt8>();
730760
}
731761

732762
bool useDefaultImplementationForNulls() const override
@@ -736,6 +766,16 @@ class FunctionIn : public IFunction
736766

737767
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
738768
{
769+
const ColumnWithTypeAndName & left_arg = block.getByPosition(arguments[0]);
770+
if constexpr (!ignore_null)
771+
{
772+
if (left_arg.type->onlyNull())
773+
{
774+
block.getByPosition(result).column =
775+
block.getByPosition(result).type->createColumnConst(block.rows(),Null());
776+
return;
777+
}
778+
}
739779
/// Second argument must be ColumnSet.
740780
ColumnPtr column_set_ptr = block.getByPosition(arguments[1]).column;
741781
const ColumnSet * column_set = typeid_cast<const ColumnSet *>(&*column_set_ptr);
@@ -746,7 +786,6 @@ class FunctionIn : public IFunction
746786
Block block_of_key_columns;
747787

748788
/// First argument may be tuple or single column.
749-
const ColumnWithTypeAndName & left_arg = block.getByPosition(arguments[0]);
750789
const ColumnTuple * tuple = typeid_cast<const ColumnTuple *>(left_arg.column.get());
751790
const ColumnConst * const_tuple = checkAndGetColumnConst<ColumnTuple>(left_arg.column.get());
752791
const DataTypeTuple * type_tuple = typeid_cast<const DataTypeTuple *>(left_arg.type.get());
@@ -769,7 +808,55 @@ class FunctionIn : public IFunction
769808
else
770809
block_of_key_columns.insert(left_arg);
771810

772-
block.getByPosition(result).column = column_set->getData()->execute(block_of_key_columns, negative);
811+
if constexpr (ignore_null)
812+
{
813+
block.getByPosition(result).column = column_set->getData()->execute(block_of_key_columns, negative);
814+
}
815+
else
816+
{
817+
bool set_contains_null = column_set->getData()->containsNullValue();
818+
bool return_nullable = left_arg.type->isNullable() || set_contains_null;
819+
if (return_nullable)
820+
{
821+
auto nested_res = column_set->getData()->execute(block_of_key_columns, negative);
822+
if (left_arg.column->isColumnNullable())
823+
{
824+
ColumnPtr result_null_map_column = dynamic_cast<const ColumnNullable &>(*left_arg.column).getNullMapColumnPtr();
825+
if (set_contains_null)
826+
{
827+
MutableColumnPtr mutable_result_null_map_column = (*std::move(result_null_map_column)).mutate();
828+
NullMap & result_null_map = dynamic_cast<ColumnUInt8 &>(*mutable_result_null_map_column).getData();
829+
auto uint8_column = checkAndGetColumn<ColumnUInt8>(nested_res.get());
830+
const auto & data = uint8_column->getData();
831+
for (size_t i = 0, size = result_null_map.size(); i < size; i++)
832+
{
833+
if (data[i] == negative)
834+
result_null_map[i] = 1;
835+
}
836+
result_null_map_column = std::move(mutable_result_null_map_column);
837+
}
838+
block.getByPosition(result).column = ColumnNullable::create(nested_res, result_null_map_column);
839+
}
840+
else
841+
{
842+
auto col_null_map = ColumnUInt8::create();
843+
ColumnUInt8::Container & vec_null_map = col_null_map->getData();
844+
vec_null_map.assign(block.rows(), (UInt8) 0);
845+
auto uint8_column = checkAndGetColumn<ColumnUInt8>(nested_res.get());
846+
const auto & data = uint8_column->getData();
847+
for (size_t i = 0, size = vec_null_map.size(); i < size; i++)
848+
{
849+
if (data[i] == negative)
850+
vec_null_map[i] = 1;
851+
}
852+
block.getByPosition(result).column = ColumnNullable::create(nested_res, std::move(col_null_map));
853+
}
854+
}
855+
else
856+
{
857+
block.getByPosition(result).column = column_set->getData()->execute(block_of_key_columns, negative);
858+
}
859+
}
773860
}
774861
};
775862

@@ -1865,10 +1952,12 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory)
18651952
factory.registerFunction<FunctionBar>();
18661953
factory.registerFunction<FunctionHasColumnInTable>();
18671954

1868-
factory.registerFunction<FunctionIn<false, false>>();
1869-
factory.registerFunction<FunctionIn<false, true>>();
1870-
factory.registerFunction<FunctionIn<true, false>>();
1871-
factory.registerFunction<FunctionIn<true, true>>();
1955+
factory.registerFunction<FunctionIn<false, false, true>>();
1956+
factory.registerFunction<FunctionIn<false, true, true>>();
1957+
factory.registerFunction<FunctionIn<true, false, true>>();
1958+
factory.registerFunction<FunctionIn<true, true, true>>();
1959+
factory.registerFunction<FunctionIn<true, false, false>>();
1960+
factory.registerFunction<FunctionIn<false, false, false>>();
18721961

18731962
factory.registerFunction<FunctionIsFinite>();
18741963
factory.registerFunction<FunctionIsInfinite>();

Diff for: dbms/src/Interpreters/ExpressionAnalyzer.cpp

+1-10
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include <Parsers/formatAST.h>
5353

5454
#include <Functions/FunctionFactory.h>
55+
#include <Functions/FunctionHelpers.h>
5556
#include <Functions/IFunction.h>
5657

5758
#include <ext/range.h>
@@ -132,16 +133,6 @@ const std::unordered_set<String> possibly_injective_function_names
132133
namespace
133134
{
134135

135-
bool functionIsInOperator(const String & name)
136-
{
137-
return name == "in" || name == "notIn";
138-
}
139-
140-
bool functionIsInOrGlobalInOperator(const String & name)
141-
{
142-
return name == "in" || name == "notIn" || name == "globalIn" || name == "globalNotIn";
143-
}
144-
145136
void removeDuplicateColumns(NamesAndTypesList & columns)
146137
{
147138
std::set<String> names;

Diff for: dbms/src/Interpreters/Set.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,8 @@ void Set::createFromAST(const DataTypes & types, ASTPtr node, const Context & co
222222

223223
if (!value.isNull())
224224
columns[0]->insert(value);
225+
else
226+
setContainsNullValue(true);
225227
}
226228
else if (ASTFunction * func = typeid_cast<ASTFunction *>(elem.get()))
227229
{
@@ -277,6 +279,10 @@ std::vector<const tipb::Expr *> Set::createFromDAGExpr(const DataTypes & types,
277279
MutableColumns columns = header.cloneEmptyColumns();
278280
std::vector<const tipb::Expr *> remainingExprs;
279281

282+
// if the left arg is a null constant, just return without decoding the children exprs
283+
if (types[0]->onlyNull())
284+
return remainingExprs;
285+
280286
for (int i = 1; i < expr.children_size(); i++)
281287
{
282288
auto & child = expr.children(i);
@@ -292,6 +298,8 @@ std::vector<const tipb::Expr *> Set::createFromDAGExpr(const DataTypes & types,
292298

293299
if (!value.isNull())
294300
columns[0]->insert(value);
301+
else
302+
setContainsNullValue(true);
295303
}
296304

297305
Block block = header.cloneWithColumns(std::move(columns));

Diff for: dbms/src/Interpreters/Set.h

+5
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ class Set
7878

7979
SetElements & getSetElements() { return *set_elements.get(); }
8080

81+
void setContainsNullValue(bool contains_null_value_) { contains_null_value = contains_null_value_; }
82+
bool containsNullValue() const { return contains_null_value; }
83+
8184
private:
8285
size_t keys_size;
8386
Sizes key_sizes;
@@ -109,6 +112,8 @@ class Set
109112
/// Limitations on the maximum size of the set
110113
SizeLimits limits;
111114

115+
bool contains_null_value = false;
116+
112117
/// If in the left part columns contains the same types as the elements of the set.
113118
void executeOrdinary(
114119
const ColumnRawPtrs & key_columns,

0 commit comments

Comments
 (0)