diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc
index 33f74637ee7..e036c4826c8 100644
--- a/cpp/src/arrow/dataset/file_parquet.cc
+++ b/cpp/src/arrow/dataset/file_parquet.cc
@@ -131,7 +131,6 @@ static Result<SchemaManifest> GetSchemaManifest(
 
 static std::shared_ptr<StructScalar> MakeMinMaxScalar(std::shared_ptr<Scalar> min,
                                                       std::shared_ptr<Scalar> max) {
-  DCHECK(min->type->Equals(max->type));
   return std::make_shared<StructScalar>(ScalarVector{min, max},
                                         struct_({
                                             field("min", min->type),
@@ -171,6 +170,15 @@ static std::shared_ptr<StructScalar> ColumnChunkStatisticsAsStructScalar(
     return nullptr;
   }
 
+  auto maybe_min = min->CastTo(field->type());
+  auto maybe_max = max->CastTo(field->type());
+  if (maybe_min.ok() && maybe_max.ok()) {
+    min = maybe_min.MoveValueUnsafe();
+    max = maybe_max.MoveValueUnsafe();
+  } else {
+    return nullptr;
+  }
+
   return MakeMinMaxScalar(std::move(min), std::move(max));
 }
 
diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc
index 8645cbaf4dd..d0a22e98382 100644
--- a/cpp/src/arrow/dataset/file_parquet_test.cc
+++ b/cpp/src/arrow/dataset/file_parquet_test.cc
@@ -480,21 +480,20 @@ TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragments) {
 }
 
 TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragmentsUsingStringColumn) {
-  auto table =
-      TableFromJSON(schema({field("x", utf8())}), {
-                                                      R"([{"x": "a"}, {"x": "a"}])",
-                                                      R"([{"x": "b"}, {"x": "b"}])",
-                                                      R"([{"x": "c"}, {"x": "c"}])",
-                                                      R"([{"x": "a"}, {"x": "b"}])",
-                                                  });
+  auto table = TableFromJSON(schema({field("x", utf8())}),
+                             {
+                                 R"([{"x": "a"}])",
+                                 R"([{"x": "b"}, {"x": "b"}])",
+                                 R"([{"x": "c"}, {"x": "c"}, {"x": "c"}])",
+                                 R"([{"x": "a"}, {"x": "b"}, {"x": "c"}, {"x": "d"}])",
+                             });
   TableBatchReader reader(*table);
   auto source = GetFileSource(&reader);
 
   opts_ = ScanOptions::Make(reader.schema());
   ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source));
 
-  // TODO(bkietz): support strings in StatisticsAsScalars
-  // CountRowGroupsInFragment(fragment, {0, 3}, "x"_ == "a");
+  CountRowGroupsInFragment(fragment, {0, 3}, "x"_ == "a");
 }
 
 TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) {
diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc
index 968d55c4318..348b8a63508 100644
--- a/cpp/src/arrow/dataset/filter.cc
+++ b/cpp/src/arrow/dataset/filter.cc
@@ -1179,8 +1179,8 @@ struct TreeEvaluator::Impl {
                         Result<Datum> kernel(const Datum& left, const Datum& right,
                                              ExecContext* ctx)) const {
-    ARROW_ASSIGN_OR_RAISE(auto lhs, Evaluate(*expr.left_operand()));
-    ARROW_ASSIGN_OR_RAISE(auto rhs, Evaluate(*expr.right_operand()));
+    ARROW_ASSIGN_OR_RAISE(Datum lhs, Evaluate(*expr.left_operand()));
+    ARROW_ASSIGN_OR_RAISE(Datum rhs, Evaluate(*expr.right_operand()));
 
     if (lhs.is_scalar()) {
       ARROW_ASSIGN_OR_RAISE(
@@ -1200,7 +1200,7 @@ struct TreeEvaluator::Impl {
   }
 
   Result<Datum> operator()(const NotExpression& expr) const {
-    ARROW_ASSIGN_OR_RAISE(auto to_invert, Evaluate(*expr.operand()));
+    ARROW_ASSIGN_OR_RAISE(Datum to_invert, Evaluate(*expr.operand()));
     if (IsNullDatum(to_invert)) {
       return NullDatum();
     }
@@ -1214,7 +1214,7 @@ struct TreeEvaluator::Impl {
   }
 
   Result<Datum> operator()(const InExpression& expr) const {
-    ARROW_ASSIGN_OR_RAISE(auto operand_values, Evaluate(*expr.operand()));
+    ARROW_ASSIGN_OR_RAISE(Datum operand_values, Evaluate(*expr.operand()));
     if (IsNullDatum(operand_values)) {
       return Datum(expr.set()->null_count() != 0);
     }
@@ -1224,7 +1224,7 @@ struct TreeEvaluator::Impl {
   }
 
   Result<Datum> operator()(const IsValidExpression& expr) const {
-    ARROW_ASSIGN_OR_RAISE(auto operand_values, Evaluate(*expr.operand()));
+    ARROW_ASSIGN_OR_RAISE(Datum operand_values, Evaluate(*expr.operand()));
     if (IsNullDatum(operand_values)) {
       return Datum(false);
     }
@@ -1255,14 +1255,42 @@ struct TreeEvaluator::Impl {
   }
 
   Result<Datum> operator()(const ComparisonExpression& expr) const {
-    ARROW_ASSIGN_OR_RAISE(auto lhs, Evaluate(*expr.left_operand()));
-    ARROW_ASSIGN_OR_RAISE(auto rhs, Evaluate(*expr.right_operand()));
+    ARROW_ASSIGN_OR_RAISE(Datum lhs, Evaluate(*expr.left_operand()));
+    ARROW_ASSIGN_OR_RAISE(Datum rhs, Evaluate(*expr.right_operand()));
 
     if (IsNullDatum(lhs) || IsNullDatum(rhs)) {
       return Datum(std::make_shared<BooleanScalar>());
     }
 
-    DCHECK(lhs.is_array());
+    if (lhs.type()->id() == Type::DICTIONARY && rhs.type()->id() == Type::DICTIONARY) {
+      if (lhs.is_array() && rhs.is_array()) {
+        // decode dictionary arrays
+        for (Datum* arg : {&lhs, &rhs}) {
+          auto dict = checked_pointer_cast<DictionaryArray>(arg->make_array());
+          ARROW_ASSIGN_OR_RAISE(*arg, compute::Take(dict->dictionary(), dict->indices(),
+                                                    compute::TakeOptions::Defaults()));
+        }
+      } else if (lhs.is_array() || rhs.is_array()) {
+        auto dict = checked_pointer_cast<DictionaryArray>(
+            (lhs.is_array() ? lhs : rhs).make_array());
+
+        ARROW_ASSIGN_OR_RAISE(auto scalar, checked_cast<const DictionaryScalar&>(
+                                               *(lhs.is_scalar() ? lhs : rhs).scalar())
+                                               .GetEncodedValue());
+        if (lhs.is_array()) {
+          lhs = dict->dictionary();
+          rhs = std::move(scalar);
+        } else {
+          lhs = std::move(scalar);
+          rhs = dict->dictionary();
+        }
+        ARROW_ASSIGN_OR_RAISE(
+            Datum out_dict,
+            compute::Compare(lhs, rhs, compute::CompareOptions(expr.op()), &ctx_));
+
+        return compute::Take(out_dict, dict->indices(), compute::TakeOptions::Defaults());
+      }
+    }
 
     return compute::Compare(lhs, rhs, compute::CompareOptions(expr.op()), &ctx_);
   }
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index cab6f700c34..5f348f763be 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -826,6 +826,24 @@ def test_fragments_parquet_row_groups(tempdir):
     assert len(result) == 1
 
 
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments_parquet_row_groups_dictionary(tempdir):
+    import pandas as pd
+
+    df = pd.DataFrame(dict(col1=['a', 'b'], col2=[1, 2]))
+    df['col1'] = df['col1'].astype("category")
+
+    import pyarrow.parquet as pq
+    pq.write_table(pa.table(df), tempdir / "test_filter_dictionary.parquet")
+
+    import pyarrow.dataset as ds
+    dataset = ds.dataset(tempdir / 'test_filter_dictionary.parquet')
+    result = dataset.to_table(filter=ds.field("col1") == "a")
+
+    assert (df.iloc[0] == result.to_pandas()).all().all()
+
+
 @pytest.mark.pandas
 @pytest.mark.parquet
 def test_fragments_parquet_ensure_metadata(tempdir, open_logging_fs):