diff --git a/velox/duckdb/conversion/DuckConversion.cpp b/velox/duckdb/conversion/DuckConversion.cpp index 81aabe28a67..442798445ef 100644 --- a/velox/duckdb/conversion/DuckConversion.cpp +++ b/velox/duckdb/conversion/DuckConversion.cpp @@ -196,6 +196,8 @@ variant duckValueToVariant(const Value& val) { return variant(val.GetValue()); case LogicalTypeId::BLOB: return variant::binary(val.GetValue()); + case LogicalTypeId::DATE: + return variant::date(val.GetValue<::duckdb::date_t>().days); default: throw std::runtime_error( "unsupported type for duckdb value -> velox variant conversion: " + diff --git a/velox/duckdb/conversion/DuckParser.cpp b/velox/duckdb/conversion/DuckParser.cpp index 43ea24799e7..ab8a9bd50cc 100644 --- a/velox/duckdb/conversion/DuckParser.cpp +++ b/velox/duckdb/conversion/DuckParser.cpp @@ -358,8 +358,21 @@ std::shared_ptr parseOperatorExpr( TypePtr valueType = UNKNOWN(); for (auto i = 0; i < numValues; i++) { - if (auto constantExpr = dynamic_cast( - operExpr.children[i + 1].get())) { + auto valueExpr = operExpr.children[i + 1].get(); + if (const auto castExpr = dynamic_cast(valueExpr)) { + if (castExpr->child->GetExpressionType() == + ExpressionType::VALUE_CONSTANT) { + auto constExpr = + dynamic_cast(castExpr->child.get()); + auto value = + constExpr->value.CastAs(castExpr->cast_type, !castExpr->try_cast); + values.emplace_back(duckValueToVariant(value)); + valueType = toVeloxType(castExpr->cast_type); + continue; + } + } + + if (auto constantExpr = dynamic_cast(valueExpr)) { auto& value = constantExpr->value; if (options.parseDecimalAsDouble && value.type().id() == duckdb::LogicalTypeId::DECIMAL) { @@ -369,9 +382,10 @@ std::shared_ptr parseOperatorExpr( if (!value.IsNull()) { valueType = toVeloxType(value.type()); } - } else { - VELOX_UNSUPPORTED("IN list values need to be constant"); + continue; } + + VELOX_UNSUPPORTED("IN list values need to be constant"); } std::vector> params; diff --git a/velox/functions/prestosql/InPredicate.cpp b/velox/functions/prestosql/InPredicate.cpp index b740e442230..f59940bfb52 100644 --- a/velox/functions/prestosql/InPredicate.cpp +++ b/velox/functions/prestosql/InPredicate.cpp @@ -113,6 +113,29 @@ std::pair, bool> createBytesValuesFilter( return {std::make_unique(values, nullAllowed), false}; } +std::pair, bool> createDateValuesFilter( + const std::vector& inputArgs) { + auto valuesPair = toValues(inputArgs); + if (!valuesPair.has_value()) { + return {nullptr, false}; + } + + const auto& values = valuesPair.value().first; + bool nullAllowed = valuesPair.value().second; + + if (values.empty() && nullAllowed) { + return {nullptr, true}; + } + VELOX_USER_CHECK( + !values.empty(), + "IN predicate expects at least one non-null value in the in-list"); + std::vector dayValues; + for (auto date : values) { + dayValues.push_back(date.days()); + } + return {common::createBigintValues(dayValues, nullAllowed), false}; +} + class InPredicate : public exec::VectorFunction { public: explicit InPredicate(std::unique_ptr filter, bool alwaysNull) @@ -143,6 +166,9 @@ class InPredicate : public exec::VectorFunction { case TypeKind::VARBINARY: filter = createBytesValuesFilter(inputArgs); break; + case TypeKind::DATE: + filter = createDateValuesFilter(inputArgs); + break; case TypeKind::UNKNOWN: filter = {nullptr, true}; break; @@ -196,6 +222,11 @@ class InPredicate : public exec::VectorFunction { return filter_->testBytes(value.data(), value.size()); }); break; + case TypeKind::DATE: + applyTyped(rows, input, context, result, [&](Date value) { + return filter_->testInt64(value.days()); + }); + break; default: VELOX_UNSUPPORTED( "Unsupported input type for the IN predicate: {}", @@ -207,7 +238,13 @@ class InPredicate : public exec::VectorFunction { // tinyint|smallint|integer|bigint|varchar... -> boolean std::vector> signatures; for (auto& type : - {"tinyint", "smallint", "integer", "bigint", "varchar", "varbinary"}) { + {"tinyint", + "smallint", + "integer", + "bigint", + "varchar", + "varbinary", + "date"}) { signatures.emplace_back(exec::FunctionSignatureBuilder() .returnType("boolean") .argumentType(type) diff --git a/velox/functions/prestosql/tests/InPredicateTest.cpp b/velox/functions/prestosql/tests/InPredicateTest.cpp index 5b5662a66ce..24b0ab8dfc0 100644 --- a/velox/functions/prestosql/tests/InPredicateTest.cpp +++ b/velox/functions/prestosql/tests/InPredicateTest.cpp @@ -318,6 +318,33 @@ TEST_F(InPredicateTest, varbinary) { assertEqualVectors(makeConstant(true, input->size()), result); } +TEST_F(InPredicateTest, date) { + auto dateValue = Date(); + parseTo("2000-01-01", dateValue); + + auto input = makeRowVector({ + makeNullableFlatVector({dateValue}, DATE()), + }); + + assertEqualVectors( + makeConstant(true, input->size()), + evaluate("c0 IN (DATE '2000-01-01')", input)); + + assertEqualVectors( + makeConstant(false, input->size()), + evaluate("c0 IN (DATE '2000-02-01')", input)); + + assertEqualVectors( + makeConstant(false, input->size()), + evaluate("c0 IN (DATE '2000-02-01', DATE '2000-03-04')", input)); + + assertEqualVectors( + makeConstant(true, input->size()), + evaluate( + "c0 IN (DATE '2000-02-01', DATE '2000-03-04', DATE '2000-01-01')", + input)); +} + TEST_F(InPredicateTest, reusableResult) { std::string predicate = "c0 IN (1, 2)"; auto input = makeRowVector({makeNullableFlatVector({0, 1, 2, 3})});