Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions velox/duckdb/conversion/DuckConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ variant duckValueToVariant(const Value& val) {
return variant(val.GetValue<std::string>());
case LogicalTypeId::BLOB:
return variant::binary(val.GetValue<std::string>());
case LogicalTypeId::DATE:
return variant::date(val.GetValue<::duckdb::date_t>().days);
default:
throw std::runtime_error(
"unsupported type for duckdb value -> velox variant conversion: " +
Expand Down
22 changes: 18 additions & 4 deletions velox/duckdb/conversion/DuckParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,21 @@ std::shared_ptr<const core::IExpr> parseOperatorExpr(

TypePtr valueType = UNKNOWN();
for (auto i = 0; i < numValues; i++) {
if (auto constantExpr = dynamic_cast<ConstantExpression*>(
operExpr.children[i + 1].get())) {
auto valueExpr = operExpr.children[i + 1].get();
if (const auto castExpr = dynamic_cast<CastExpression*>(valueExpr)) {
if (castExpr->child->GetExpressionType() ==
ExpressionType::VALUE_CONSTANT) {
auto constExpr =
dynamic_cast<ConstantExpression*>(castExpr->child.get());
auto value =
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check constExpr is non-null before de-referencing it in the next line. We should ensure that cast is not over any other expression here.

constExpr->value.CastAs(castExpr->cast_type, !castExpr->try_cast);
values.emplace_back(duckValueToVariant(value));
valueType = toVeloxType(castExpr->cast_type);
continue;
}
}

if (auto constantExpr = dynamic_cast<ConstantExpression*>(valueExpr)) {
auto& value = constantExpr->value;
if (options.parseDecimalAsDouble &&
value.type().id() == duckdb::LogicalTypeId::DECIMAL) {
Expand All @@ -369,9 +382,10 @@ std::shared_ptr<const core::IExpr> parseOperatorExpr(
if (!value.IsNull()) {
valueType = toVeloxType(value.type());
}
} else {
VELOX_UNSUPPORTED("IN list values need to be constant");
continue;
}

VELOX_UNSUPPORTED("IN list values need to be constant");
}

std::vector<std::shared_ptr<const core::IExpr>> params;
Expand Down
39 changes: 38 additions & 1 deletion velox/functions/prestosql/InPredicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,29 @@ std::pair<std::unique_ptr<common::Filter>, bool> createBytesValuesFilter(
return {std::make_unique<common::BytesValues>(values, nullAllowed), false};
}

std::pair<std::unique_ptr<common::Filter>, bool> createDateValuesFilter(
const std::vector<exec::VectorFunctionArg>& inputArgs) {
auto valuesPair = toValues<Date, Date>(inputArgs);
if (!valuesPair.has_value()) {
return {nullptr, false};
}

const auto& values = valuesPair.value().first;
bool nullAllowed = valuesPair.value().second;

if (values.empty() && nullAllowed) {
return {nullptr, true};
}
VELOX_USER_CHECK(
!values.empty(),
"IN predicate expects at least one non-null value in the in-list");
std::vector<int64_t> dayValues;
for (auto date : values) {
dayValues.push_back(date.days());
}
return {common::createBigintValues(dayValues, nullAllowed), false};
}

class InPredicate : public exec::VectorFunction {
public:
explicit InPredicate(std::unique_ptr<common::Filter> filter, bool alwaysNull)
Expand Down Expand Up @@ -143,6 +166,9 @@ class InPredicate : public exec::VectorFunction {
case TypeKind::VARBINARY:
filter = createBytesValuesFilter(inputArgs);
break;
case TypeKind::DATE:
filter = createDateValuesFilter(inputArgs);
break;
case TypeKind::UNKNOWN:
filter = {nullptr, true};
break;
Expand Down Expand Up @@ -196,6 +222,11 @@ class InPredicate : public exec::VectorFunction {
return filter_->testBytes(value.data(), value.size());
});
break;
case TypeKind::DATE:
applyTyped<Date>(rows, input, context, result, [&](Date value) {
return filter_->testInt64(value.days());
});
break;
default:
VELOX_UNSUPPORTED(
"Unsupported input type for the IN predicate: {}",
Expand All @@ -207,7 +238,13 @@ class InPredicate : public exec::VectorFunction {
// tinyint|smallint|integer|bigint|varchar... -> boolean
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
for (auto& type :
{"tinyint", "smallint", "integer", "bigint", "varchar", "varbinary"}) {
{"tinyint",
"smallint",
"integer",
"bigint",
"varchar",
"varbinary",
"date"}) {
signatures.emplace_back(exec::FunctionSignatureBuilder()
.returnType("boolean")
.argumentType(type)
Expand Down
27 changes: 27 additions & 0 deletions velox/functions/prestosql/tests/InPredicateTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,33 @@ TEST_F(InPredicateTest, varbinary) {
assertEqualVectors(makeConstant(true, input->size()), result);
}

TEST_F(InPredicateTest, date) {
auto dateValue = Date();
parseTo("2000-01-01", dateValue);

auto input = makeRowVector({
makeNullableFlatVector<Date>({dateValue}, DATE()),
});

assertEqualVectors(
makeConstant(true, input->size()),
evaluate("c0 IN (DATE '2000-01-01')", input));

assertEqualVectors(
makeConstant(false, input->size()),
evaluate("c0 IN (DATE '2000-02-01')", input));

assertEqualVectors(
makeConstant(false, input->size()),
evaluate("c0 IN (DATE '2000-02-01', DATE '2000-03-04')", input));

assertEqualVectors(
makeConstant(true, input->size()),
evaluate(
"c0 IN (DATE '2000-02-01', DATE '2000-03-04', DATE '2000-01-01')",
input));
}

TEST_F(InPredicateTest, reusableResult) {
std::string predicate = "c0 IN (1, 2)";
auto input = makeRowVector({makeNullableFlatVector<int32_t>({0, 1, 2, 3})});
Expand Down