diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index ef92ae09fe7..24dca59e63a 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -99,6 +99,8 @@ namespace { std::string PrintDatum(const Datum& datum) { if (datum.is_scalar()) { + if (!datum.scalar()->is_valid) return "null"; + switch (datum.type()->id()) { case Type::STRING: case Type::LARGE_STRING: @@ -114,6 +116,7 @@ std::string PrintDatum(const Datum& datum) { default: break; } + return datum.scalar()->ToString(); } return datum.ToString(); @@ -688,27 +691,27 @@ std::vector GuaranteeConjunctionMembers( // conjunction_members Status ExtractKnownFieldValuesImpl( std::vector* conjunction_members, - std::unordered_map* known_values) { - auto unconsumed_end = std::partition( - conjunction_members->begin(), conjunction_members->end(), - [](const Expression& expr) { - // search for an equality conditions between a field and a literal - auto call = expr.call(); - if (!call) return true; - - if (call->function_name == "equal") { - auto ref = call->arguments[0].field_ref(); - auto lit = call->arguments[1].literal(); - return !(ref && lit); - } - - if (call->function_name == "is_null" || call->function_name == "is_valid") { - auto ref = call->arguments[0].field_ref(); - return !ref; - } - - return true; - }); + std::unordered_map* known_values) { + auto unconsumed_end = + std::partition(conjunction_members->begin(), conjunction_members->end(), + [](const Expression& expr) { + // search for an equality conditions between a field and a literal + auto call = expr.call(); + if (!call) return true; + + if (call->function_name == "equal") { + auto ref = call->arguments[0].field_ref(); + auto lit = call->arguments[1].literal(); + return !(ref && lit); + } + + if (call->function_name == "is_null") { + auto ref = call->arguments[0].field_ref(); + return !ref; + } + + return true; + }); for (auto it = unconsumed_end; it != conjunction_members->end(); ++it) { auto call = CallNotNull(*it); @@ -719,10 +722,7 @@ Status ExtractKnownFieldValuesImpl( known_values->emplace(*ref, *lit); } else if (call->function_name == "is_null") { auto ref = call->arguments[0].field_ref(); - known_values->emplace(*ref, false); - } else if (call->function_name == "is_valid") { - auto ref = call->arguments[0].field_ref(); - known_values->emplace(*ref, true); + known_values->emplace(*ref, std::make_shared()); } } @@ -733,16 +733,16 @@ Status ExtractKnownFieldValuesImpl( } // namespace -Result> -ExtractKnownFieldValues(const Expression& guaranteed_true_predicate) { +Result> ExtractKnownFieldValues( + const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); - std::unordered_map known_values; + std::unordered_map known_values; RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); return known_values; } Result ReplaceFieldsWithKnownValues( - const std::unordered_map& known_values, + const std::unordered_map& known_values, Expression expr) { if (!expr.IsBound()) { return Status::Invalid( @@ -755,11 +755,7 @@ Result ReplaceFieldsWithKnownValues( if (auto ref = expr.field_ref()) { auto it = known_values.find(*ref); if (it != known_values.end()) { - const auto& known_value = it->second; - if (!known_value.concrete()) { - return expr; - } - auto lit = known_value.datum; + Datum lit = it->second; if (expr.type()->id() == Type::DICTIONARY) { if (lit.is_scalar()) { // FIXME the "right" way to support this is adding support for scalars to @@ -779,22 +775,6 @@ Result ReplaceFieldsWithKnownValues( ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, expr.type())); return literal(std::move(lit)); } - } else if (auto call = expr.call()) { - if (call->function_name == "is_null") { - if (auto ref = call->arguments[0].field_ref()) { - auto it = known_values.find(*ref); - if (it != known_values.end()) { - return literal(!it->second.valid); - } - } - } else if (call->function_name == "is_valid") { - if (auto ref = call->arguments[0].field_ref()) { - auto it = known_values.find(*ref); - if (it != known_values.end()) { - return literal(it->second.valid); - } - } - } } return expr; }, @@ -971,7 +951,7 @@ Result SimplifyWithGuarantee(Expression expr, const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); - std::unordered_map known_values; + std::unordered_map known_values; RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); ARROW_ASSIGN_OR_RAISE(expr, diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 1bbcb471015..d8d4093243a 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -162,27 +162,10 @@ Expression call(std::string function, std::vector arguments, ARROW_DS_EXPORT std::vector FieldsInExpression(const Expression&); -/// Represents either a concrete value or a hint that a field is valid/invalid -struct KnownFieldValue { - Datum datum; - bool valid; - - KnownFieldValue() : datum(), valid(false) {} - KnownFieldValue(const Datum& datum) // NOLINT implicit conversion - : datum(datum), valid(datum.length() != datum.null_count()) {} - KnownFieldValue(bool is_valid) // NOLINT implicit conversion - : datum(), valid(is_valid) {} - - inline bool concrete() const { return datum.kind() != Datum::Kind::NONE; } - bool operator==(const KnownFieldValue& other) const { - return datum == other.datum && valid == other.valid; - } -}; - /// Assemble a mapping from field references to known values. ARROW_DS_EXPORT -Result> -ExtractKnownFieldValues(const Expression& guaranteed_true_predicate); +Result> ExtractKnownFieldValues( + const Expression& guaranteed_true_predicate); /// \defgroup expression-passes Functions for modification of Expressions /// @@ -211,8 +194,7 @@ Result FoldConstants(Expression); /// Simplify Expressions by replacing with known values of the fields which it references. ARROW_DS_EXPORT Result ReplaceFieldsWithKnownValues( - const std::unordered_map& known_values, - Expression); + const std::unordered_map& known_values, Expression); /// Simplify an expression by replacing subexpressions based on a guarantee: /// a boolean expression which is guaranteed to evaluate to `true`. For example, this is diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index d8077be73f5..b58471e763b 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -676,9 +676,8 @@ TEST(Expression, FoldConstantsBoolean) { TEST(Expression, ExtractKnownFieldValues) { struct { - void operator()( - Expression guarantee, - std::unordered_map expected) { + void operator()(Expression guarantee, + std::unordered_map expected) { ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); EXPECT_THAT(actual, UnorderedElementsAreArray(expected)) << " guarantee: " << guarantee.ToString(); @@ -726,20 +725,20 @@ TEST(Expression, ExtractKnownFieldValues) { } TEST(Expression, ReplaceFieldsWithKnownValues) { - auto ExpectReplacesTo = [](Expression expr, - const std::unordered_map& known_values, - Expression unbound_expected) { - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto replaced, ReplaceFieldsWithKnownValues(known_values, expr)); + auto ExpectReplacesTo = + [](Expression expr, + const std::unordered_map& known_values, + Expression unbound_expected) { + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto replaced, + ReplaceFieldsWithKnownValues(known_values, expr)); - EXPECT_EQ(replaced, expected); - ExpectIdenticalIfUnchanged(replaced, expr); - }; + EXPECT_EQ(replaced, expected); + ExpectIdenticalIfUnchanged(replaced, expr); + }; - std::unordered_map i32_is_3{ - {"i32", Datum(3)}}; + std::unordered_map i32_is_3{{"i32", Datum(3)}}; ExpectReplacesTo(literal(1), i32_is_3, literal(1)); @@ -772,13 +771,18 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { literal(2), })); - std::unordered_map a_valid_b_invalid{ - {"a", true}, {"b", false}}; + std::unordered_map i32_valid_str_null{ + {"i32", Datum(3)}, {"str", MakeNullScalar(utf8())}}; + + ExpectReplacesTo(is_null(field_ref("i32")), i32_valid_str_null, is_null(literal(3))); - ExpectReplacesTo(is_null(field_ref("a")), a_valid_b_invalid, literal(false)); - ExpectReplacesTo(is_valid(field_ref("a")), a_valid_b_invalid, literal(true)); - ExpectReplacesTo(is_null(field_ref("b")), a_valid_b_invalid, literal(true)); - ExpectReplacesTo(is_valid(field_ref("b")), a_valid_b_invalid, literal(false)); + ExpectReplacesTo(is_valid(field_ref("i32")), i32_valid_str_null, is_valid(literal(3))); + + ExpectReplacesTo(is_null(field_ref("str")), i32_valid_str_null, + is_null(null_literal(utf8()))); + + ExpectReplacesTo(is_valid(field_ref("str")), i32_valid_str_null, + is_valid(null_literal(utf8()))); } struct { @@ -1042,7 +1046,7 @@ TEST(Expression, SimplifyWithGuarantee) { Simplify{is_valid(field_ref("i32"))} .WithGuarantee(is_valid(field_ref("i32"))) - .Expect(literal(true)); + .Expect(is_valid(field_ref("i32"))); } TEST(Expression, SimplifyThenExecute) { diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index e9c198e3398..de7eb8d271d 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -74,31 +74,20 @@ Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression& expr, RecordBatchProjector* projector) { ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); for (const auto& ref_value : known_values) { - const auto& known_value = ref_value.second; - if (known_value.concrete() && !known_value.datum.is_scalar()) { - return Status::Invalid("non-scalar partition key ", known_value.datum.ToString()); + if (!ref_value.second.is_scalar()) { + return Status::Invalid("non-scalar partition key ", ref_value.second.ToString()); } ARROW_ASSIGN_OR_RAISE(auto match, ref_value.first.FindOneOrNone(*projector->schema())); if (match.empty()) continue; - - const auto& field = projector->schema()->field(match[0]); - if (known_value.concrete()) { - RETURN_NOT_OK(projector->SetDefaultValue(match, known_value.datum.scalar())); - } else if (known_value.valid) { - // We know some information about the value but nothing concrete enough to set. Can - // happen if expression is something like is_valid(field_ref("a")) - continue; - } else { - RETURN_NOT_OK(projector->SetDefaultValue(match, MakeNullScalar(field->type()))); - } + RETURN_NOT_OK(projector->SetDefaultValue(match, ref_value.second.scalar())); } return Status::OK(); } -Expression ConjunctionFromGroupingRow(Scalar* row) { +inline Expression ConjunctionFromGroupingRow(Scalar* row) { ScalarVector* values = &checked_cast(row)->value; std::vector equality_expressions(values->size()); for (size_t i = 0; i < values->size(); ++i) { @@ -213,34 +202,37 @@ Result KeyValuePartitioning::Format(const Expression& expr) const { ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); for (const auto& ref_value : known_values) { - const auto& known_value = ref_value.second; - if (known_value.concrete() && !known_value.datum.is_scalar()) { - return Status::Invalid("non-scalar partition key ", known_value.datum.ToString()); + if (!ref_value.second.is_scalar()) { + return Status::Invalid("non-scalar partition key ", ref_value.second.ToString()); } ARROW_ASSIGN_OR_RAISE(auto match, ref_value.first.FindOneOrNone(*schema_)); if (match.empty()) continue; - const auto& field = schema_->field(match[0]); - - if (known_value.concrete()) { - auto value = known_value.datum.scalar(); - if (!value->type->Equals(field->type())) { - return Status::TypeError("scalar ", value->ToString(), " (of type ", *value->type, - ") is invalid for ", field->ToString()); - } + auto value = ref_value.second.scalar(); - if (value->type->id() == Type::DICTIONARY) { - ARROW_ASSIGN_OR_RAISE( - value, checked_cast(*value).GetEncodedValue()); + const auto& field = schema_->field(match[0]); + if (!value->type->Equals(field->type())) { + if (value->is_valid) { + auto maybe_converted = compute::Cast(value, field->type()); + if (!maybe_converted.ok()) { + return Status::TypeError("Error converting scalar ", value->ToString(), + " (of type ", *value->type, + ") to a partition key for ", field->ToString(), ": ", + maybe_converted.status().message()); + } + value = maybe_converted->scalar(); + } else { + value = MakeNullScalar(field->type()); } + } - values[match[0]] = std::move(value); - } else { - if (!known_value.valid) { - values[match[0]] = MakeNullScalar(field->type()); - } + if (value->type->id() == Type::DICTIONARY) { + ARROW_ASSIGN_OR_RAISE( + value, checked_cast(*value).GetEncodedValue()); } + + values[match[0]] = std::move(value); } return FormatValues(values); diff --git a/cpp/src/arrow/dataset/projector.cc b/cpp/src/arrow/dataset/projector.cc index 2ba679ce6e7..ba0eb2ddff5 100644 --- a/cpp/src/arrow/dataset/projector.cc +++ b/cpp/src/arrow/dataset/projector.cc @@ -23,6 +23,7 @@ #include #include "arrow/array.h" +#include "arrow/compute/cast.h" #include "arrow/dataset/type_fwd.h" #include "arrow/record_batch.h" #include "arrow/result.h" @@ -88,9 +89,18 @@ Status RecordBatchProjector::SetDefaultValue(FieldRef ref, auto field_type = to_->field(index)->type(); if (!field_type->Equals(scalar->type)) { - return Status::TypeError("field ", to_->field(index)->ToString(), - " cannot be materialized from scalar of type ", - *scalar->type); + if (scalar->is_valid) { + auto maybe_converted = compute::Cast(scalar, field_type); + if (!maybe_converted.ok()) { + return Status::TypeError("Field ", to_->field(index)->ToString(), + " cannot be materialized from scalar of type ", + *scalar->type, + ". Cast error: ", maybe_converted.status().message()); + } + scalar = maybe_converted->scalar(); + } else { + scalar = MakeNullScalar(field_type); + } } scalars_[index] = std::move(scalar); diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 3553c860307..e446bd481b8 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -2351,17 +2351,14 @@ def _get_partition_keys(Expression partition_expression): """ cdef: CExpression expr = partition_expression.unwrap() - pair[CFieldRef, CKnownFieldValue] ref_val + pair[CFieldRef, CDatum] ref_val out = {} for ref_val in GetResultValue(CExtractKnownFieldValues(expr)): assert ref_val.first.name() != nullptr - if ref_val.second.valid: - assert ref_val.second.datum.kind() == DatumType_SCALAR - val = pyarrow_wrap_scalar(ref_val.second.datum.scalar()) - out[frombytes(deref(ref_val.first.name()))] = val.as_py() - else: - out[frombytes(deref(ref_val.first.name()))] = None + assert ref_val.second.kind() == DatumType_SCALAR + val = pyarrow_wrap_scalar(ref_val.second.scalar()) + out[frombytes(deref(ref_val.first.name()))] = val.as_py() return out diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index 320a3f6035b..90c47ba721b 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -314,14 +314,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: const CExpression& partition_expression, CRecordBatchProjector* projector) - cdef cppclass CKnownFieldValue "arrow::dataset::KnownFieldValue": - CDatum datum - c_bool valid - CKnownFieldValue(CDatum datum) - CKnownFieldValue(c_bool valid) - c_bool operator==(const CKnownFieldValue&) const - - cdef CResult[unordered_map[CFieldRef, CKnownFieldValue, CFieldRefHash]] \ + cdef CResult[unordered_map[CFieldRef, CDatum, CFieldRefHash]] \ CExtractKnownFieldValues "arrow::dataset::ExtractKnownFieldValues"( const CExpression& partition_expression) diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index aa738f9aaea..998af512c55 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -251,6 +251,9 @@ cdef api object pyarrow_wrap_scalar(const shared_ptr[CScalar]& sp_scalar): if data_type == NULL: raise ValueError('Scalar data type was NULL') + if data_type.id() == _Type_NA: + return _NULL + if data_type.id() not in _scalar_classes: raise ValueError('Scalar type not supported') diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 0312b9b56cf..2fab1f23da4 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -26,6 +26,7 @@ import pytest import pyarrow as pa +import pyarrow.csv import pyarrow.fs as fs from pyarrow.tests.util import change_cwd, _filesystem_uri