diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index f5ab46ac603..0082d48112d 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -74,8 +74,9 @@ Result> Unique(const Datum& value, ExecContext* ctx) { return result.make_array(); } -Result DictionaryEncode(const Datum& value, ExecContext* ctx) { - return CallFunction("dictionary_encode", {value}, ctx); +Result DictionaryEncode(const Datum& value, const DictionaryEncodeOptions& options, + ExecContext* ctx) { + return CallFunction("dictionary_encode", {value}, &options, ctx); } const char kValuesFieldName[] = "values"; diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 9e9cad9e5d9..d67568e1567 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -63,6 +63,24 @@ enum class SortOrder { Descending, }; +/// \brief Options for the dictionary encode function +struct DictionaryEncodeOptions : public FunctionOptions { + /// Configure how null values will be encoded + enum NullEncodingBehavior { + /// the null value will be added to the dictionary with a proper index + ENCODE, + /// the null value will be masked in the indices array + MASK + }; + + explicit DictionaryEncodeOptions(NullEncodingBehavior null_encoding = MASK) + : null_encoding_behavior(null_encoding) {} + + static DictionaryEncodeOptions Defaults() { return DictionaryEncodeOptions(); } + + NullEncodingBehavior null_encoding_behavior = MASK; +}; + /// \brief One sort key for PartitionNthIndices (TODO) and SortIndices struct ARROW_EXPORT SortKey { explicit SortKey(std::string name, SortOrder order = SortOrder::Ascending) @@ -289,14 +307,29 @@ Result> ValueCounts(const Datum& value, ExecContext* ctx = NULLPTR); /// \brief Dictionary-encode values in an array-like object +/// +/// Any nulls encountered in the dictionary will be handled according to the +/// specified null encoding behavior. 
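+/// The default behavior is MASK, which leaves nulls masked in the indices array
+/// (the previous behavior of this function).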
+/// +/// For example, given values ["a", "b", null, "a", null] the output will be +/// (null_encoding == ENCODE) Indices: [0, 1, 2, 0, 2] / Dict: ["a", "b", null] +/// (null_encoding == MASK) Indices: [0, 1, null, 0, null] / Dict: ["a", "b"] +/// +/// If the input is already dictionary encoded this function is a no-op unless +/// it needs to modify the null_encoding (TODO) +/// /// \param[in] data array-like input /// \param[in] ctx the function execution context, optional +/// \param[in] options configures null encoding behavior /// \return result with same shape and type as input /// /// \since 1.0.0 /// \note API not yet finalized ARROW_EXPORT -Result DictionaryEncode(const Datum& data, ExecContext* ctx = NULLPTR); +Result DictionaryEncode( + const Datum& data, + const DictionaryEncodeOptions& options = DictionaryEncodeOptions::Defaults(), + ExecContext* ctx = NULLPTR); // ---------------------------------------------------------------------- // Deprecated functions diff --git a/cpp/src/arrow/compute/kernels/vector_hash.cc b/cpp/src/arrow/compute/kernels/vector_hash.cc index 34d18c24a0c..de4d3ee3022 100644 --- a/cpp/src/arrow/compute/kernels/vector_hash.cc +++ b/cpp/src/arrow/compute/kernels/vector_hash.cc @@ -58,7 +58,10 @@ class UniqueAction final : public ActionBase { using ActionBase::ActionBase; static constexpr bool with_error_status = false; - static constexpr bool with_memo_visit_null = true; + + UniqueAction(const std::shared_ptr& type, const FunctionOptions* options, + MemoryPool* pool) + : ActionBase(type, pool) {} Status Reset() { return Status::OK(); } @@ -76,6 +79,8 @@ class UniqueAction final : public ActionBase { template void ObserveNotFound(Index index) {} + bool ShouldEncodeNulls() { return true; } + Status Flush(Datum* out) { return Status::OK(); } Status FlushFinal(Datum* out) { return Status::OK(); } @@ -89,9 +94,9 @@ class ValueCountsAction final : ActionBase { using ActionBase::ActionBase; static constexpr bool with_error_status = true; - static constexpr bool with_memo_visit_null = true; - ValueCountsAction(const std::shared_ptr& type, MemoryPool* pool) + ValueCountsAction(const std::shared_ptr& type, const FunctionOptions* options, + MemoryPool* pool) : ActionBase(type, pool), count_builder_(pool) {} Status Reserve(const int64_t length) { @@ -147,6 +152,8 @@ class ValueCountsAction final : ActionBase { } } + bool ShouldEncodeNulls() const { return true; } + private: Int64Builder count_builder_; }; @@ -159,10 +166,14 @@ class DictEncodeAction final : public ActionBase { using ActionBase::ActionBase; static constexpr bool with_error_status = false; - static constexpr bool with_memo_visit_null = false; - DictEncodeAction(const std::shared_ptr& type, MemoryPool* pool) - : ActionBase(type, pool), indices_builder_(pool) {} + DictEncodeAction(const std::shared_ptr& type, const FunctionOptions* options, + MemoryPool* pool) + : ActionBase(type, pool), indices_builder_(pool) { + if (auto options_ptr = static_cast(options)) { + encode_options_ = *options_ptr; + } + } Status Reset() { indices_builder_.Reset(); @@ -173,12 +184,16 @@ class DictEncodeAction final : public ActionBase { template void ObserveNullFound(Index index) { - indices_builder_.UnsafeAppendNull(); + if (encode_options_.null_encoding_behavior == DictionaryEncodeOptions::MASK) { + indices_builder_.UnsafeAppendNull(); + } else { + indices_builder_.UnsafeAppend(index); + } } template void ObserveNullNotFound(Index index) { - indices_builder_.UnsafeAppendNull(); + ObserveNullFound(index); } template @@ -191,6 
+206,10 @@ class DictEncodeAction final : public ActionBase { ObserveFound(index); } + bool ShouldEncodeNulls() { + return encode_options_.null_encoding_behavior == DictionaryEncodeOptions::ENCODE; + } + Status Flush(Datum* out) { std::shared_ptr result; RETURN_NOT_OK(indices_builder_.FinishInternal(&result)); @@ -202,10 +221,14 @@ class DictEncodeAction final : public ActionBase { private: Int32Builder indices_builder_; + DictionaryEncodeOptions encode_options_; }; class HashKernel : public KernelState { public: + HashKernel() : options_(nullptr) {} + explicit HashKernel(const FunctionOptions* options) : options_(options) {} + // Reset for another run. virtual Status Reset() = 0; @@ -229,6 +252,7 @@ class HashKernel : public KernelState { virtual Status Append(const ArrayData& arr) = 0; protected: + const FunctionOptions* options_; std::mutex lock_; }; @@ -237,12 +261,12 @@ class HashKernel : public KernelState { // (NullType has a separate implementation) template + bool with_error_status = Action::with_error_status> class RegularHashKernel : public HashKernel { public: - RegularHashKernel(const std::shared_ptr& type, MemoryPool* pool) - : pool_(pool), type_(type), action_(type, pool) {} + RegularHashKernel(const std::shared_ptr& type, const FunctionOptions* options, + MemoryPool* pool) + : HashKernel(options), pool_(pool), type_(type), action_(type, options, pool) {} Status Reset() override { memo_table_.reset(new MemoTable(pool_, 0)); @@ -282,7 +306,7 @@ class RegularHashKernel : public HashKernel { &unused_memo_index); }, [this]() { - if (with_memo_visit_null) { + if (action_.ShouldEncodeNulls()) { auto on_found = [this](int32_t memo_index) { action_.ObserveNullFound(memo_index); }; @@ -318,16 +342,14 @@ class RegularHashKernel : public HashKernel { [this]() { // Null Status s = Status::OK(); - if (with_memo_visit_null) { - auto on_found = [this](int32_t memo_index) { - action_.ObserveNullFound(memo_index); - }; - auto on_not_found = [this, &s](int32_t memo_index) { - action_.ObserveNullNotFound(memo_index, &s); - }; + auto on_found = [this](int32_t memo_index) { + action_.ObserveNullFound(memo_index); + }; + auto on_not_found = [this, &s](int32_t memo_index) { + action_.ObserveNullNotFound(memo_index, &s); + }; + if (action_.ShouldEncodeNulls()) { memo_table_->GetOrInsertNull(std::move(on_found), std::move(on_not_found)); - } else { - action_.ObserveNullNotFound(-1); } return s; }); @@ -345,18 +367,23 @@ class RegularHashKernel : public HashKernel { // ---------------------------------------------------------------------- // Hash kernel implementation for nulls -template +template class NullHashKernel : public HashKernel { public: - NullHashKernel(const std::shared_ptr& type, MemoryPool* pool) - : pool_(pool), type_(type), action_(type, pool) {} + NullHashKernel(const std::shared_ptr& type, const FunctionOptions* options, + MemoryPool* pool) + : pool_(pool), type_(type), action_(type, options, pool) {} Status Reset() override { return action_.Reset(); } - Status Append(const ArrayData& arr) override { + Status Append(const ArrayData& arr) override { return DoAppend(arr); } + + template + enable_if_t DoAppend(const ArrayData& arr) { RETURN_NOT_OK(action_.Reserve(arr.length)); for (int64_t i = 0; i < arr.length; ++i) { if (i == 0) { + seen_null_ = true; action_.ObserveNullNotFound(0); } else { action_.ObserveNullFound(0); @@ -365,12 +392,31 @@ class NullHashKernel : public HashKernel { return Status::OK(); } + template + enable_if_t DoAppend(const ArrayData& arr) { + Status s = 
Status::OK(); + RETURN_NOT_OK(action_.Reserve(arr.length)); + for (int64_t i = 0; i < arr.length; ++i) { + if (seen_null_ == false && i == 0) { + seen_null_ = true; + action_.ObserveNullNotFound(0, &s); + } else { + action_.ObserveNullFound(0); + } + } + return s; + } + Status Flush(Datum* out) override { return action_.Flush(out); } Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); } Status GetDictionary(std::shared_ptr* out) override { - // TODO(wesm): handle null being a valid dictionary value - auto null_array = std::make_shared(0); + std::shared_ptr null_array; + if (seen_null_) { + null_array = std::make_shared(1); + } else { + null_array = std::make_shared(0); + } *out = null_array->data(); return Status::OK(); } @@ -380,6 +426,7 @@ class NullHashKernel : public HashKernel { protected: MemoryPool* pool_; std::shared_ptr type_; + bool seen_null_ = false; Action action_; }; @@ -451,8 +498,8 @@ struct HashKernelTraits> { template std::unique_ptr HashInitImpl(KernelContext* ctx, const KernelInitArgs& args) { using HashKernelType = typename HashKernelTraits::HashKernel; - auto result = ::arrow::internal::make_unique(args.inputs[0].type, - ctx->memory_pool()); + auto result = ::arrow::internal::make_unique( + args.inputs[0].type, args.options, ctx->memory_pool()); ctx->SetStatus(result->Reset()); return std::move(result); } @@ -507,6 +554,8 @@ KernelInit GetHashInit(Type::type type_id) { } } +using DictionaryEncodeState = OptionsWrapper; + template std::unique_ptr DictionaryHashInit(KernelContext* ctx, const KernelInitArgs& args) { @@ -639,9 +688,11 @@ const FunctionDoc value_counts_doc( "Nulls in the input are ignored."), {"array"}); +const auto kDefaultDictionaryEncodeOptions = DictionaryEncodeOptions::Defaults(); const FunctionDoc dictionary_encode_doc( "Dictionary-encode array", - ("Return a dictionary-encoded version of the input array."), {"array"}); + ("Return a dictionary-encoded version of the input array."), {"array"}, + "DictionaryEncodeOptions"); } // namespace @@ -691,7 +742,8 @@ void RegisterVectorHash(FunctionRegistry* registry) { // Unique and ValueCounts output unchunked arrays base.output_chunked = true; auto dict_encode = std::make_shared("dictionary_encode", Arity::Unary(), - &dictionary_encode_doc); + &dictionary_encode_doc, + &kDefaultDictionaryEncodeOptions); AddHashKernels(dict_encode.get(), base, OutputType(DictEncodeOutput)); // Calling dictionary_encode on dictionary input not supported, but if it diff --git a/cpp/src/arrow/compute/kernels/vector_hash_test.cc b/cpp/src/arrow/compute/kernels/vector_hash_test.cc index e9ae4a64d97..179792e2141 100644 --- a/cpp/src/arrow/compute/kernels/vector_hash_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_hash_test.cc @@ -305,6 +305,11 @@ TEST_F(TestHashKernel, ValueCountsBoolean) { ArrayFromJSON(boolean(), "[false]"), ArrayFromJSON(int64(), "[2]")); } +TEST_F(TestHashKernel, ValueCountsNull) { + CheckValueCounts(ArrayFromJSON(null(), "[null, null, null]"), + ArrayFromJSON(null(), "[null]"), ArrayFromJSON(int64(), "[3]")); +} + TEST_F(TestHashKernel, DictEncodeBoolean) { CheckDictEncode(boolean(), {true, true, false, true, false}, {true, false, true, true, true}, {true, false}, {}, @@ -542,6 +547,12 @@ TEST_F(TestHashKernel, UniqueDecimal) { {true, false, true, true}, expected, {1, 0, 1}); } +TEST_F(TestHashKernel, UniqueNull) { + CheckUnique(null(), {nullptr, nullptr}, {false, true}, + {nullptr}, {false}); + CheckUnique(null(), {}, {}, {}, {}); +} + TEST_F(TestHashKernel, ValueCountsDecimal) { 
std::vector values{12, 12, 11, 12}; std::vector expected{12, 0, 11}; @@ -586,6 +597,33 @@ TEST_F(TestHashKernel, DictionaryUniqueAndValueCounts) { auto different_dictionaries = *ChunkedArray::Make({input, input2}); ASSERT_RAISES(Invalid, Unique(different_dictionaries)); ASSERT_RAISES(Invalid, ValueCounts(different_dictionaries)); + + // Dictionary with encoded nulls + auto dict_with_null = ArrayFromJSON(int64(), "[10, null, 30, 40]"); + input = std::make_shared(dict_ty, indices, dict_with_null); + ex_uniques = std::make_shared(dict_ty, ex_indices, dict_with_null); + CheckUnique(input, ex_uniques); + + CheckValueCounts(input, ex_uniques, ex_counts); + + // Dictionary with masked nulls + auto indices_with_null = + ArrayFromJSON(index_ty, "[3, 0, 0, 0, null, null, 3, 0, null, 3, 0, null]"); + auto ex_indices_with_null = ArrayFromJSON(index_ty, "[3, 0, null]"); + ex_uniques = std::make_shared(dict_ty, ex_indices_with_null, dict); + input = std::make_shared(dict_ty, indices_with_null, dict); + CheckUnique(input, ex_uniques); + + CheckValueCounts(input, ex_uniques, ex_counts); + + // Dictionary with encoded AND masked nulls + auto some_indices_with_null = + ArrayFromJSON(index_ty, "[3, 0, 0, 0, 1, 1, 3, 0, null, 3, 0, null]"); + ex_uniques = + std::make_shared(dict_ty, ex_indices_with_null, dict_with_null); + input = std::make_shared(dict_ty, indices_with_null, dict_with_null); + CheckUnique(input, ex_uniques); + CheckValueCounts(input, ex_uniques, ex_counts); } } @@ -656,6 +694,33 @@ TEST_F(TestHashKernel, ZeroLengthDictionaryEncode) { ASSERT_OK(dict_result.ValidateFull()); } +TEST_F(TestHashKernel, NullEncodingSchemes) { + auto values = ArrayFromJSON(uint8(), "[1, 1, null, 2, null]"); + + // Masking should put null in the indices array + auto expected_mask_indices = ArrayFromJSON(int32(), "[0, 0, null, 1, null]"); + auto expected_mask_dictionary = ArrayFromJSON(uint8(), "[1, 2]"); + auto dictionary_type = dictionary(int32(), uint8()); + std::shared_ptr expected = std::make_shared( + dictionary_type, expected_mask_indices, expected_mask_dictionary); + + ASSERT_OK_AND_ASSIGN(Datum datum_result, DictionaryEncode(values)); + std::shared_ptr result = datum_result.make_array(); + AssertArraysEqual(*expected, *result); + + // Encoding should put null in the dictionary + auto expected_encoded_indices = ArrayFromJSON(int32(), "[0, 0, 1, 2, 1]"); + auto expected_encoded_dict = ArrayFromJSON(uint8(), "[1, null, 2]"); + expected = std::make_shared(dictionary_type, expected_encoded_indices, + expected_encoded_dict); + + auto options = DictionaryEncodeOptions::Defaults(); + options.null_encoding_behavior = DictionaryEncodeOptions::ENCODE; + ASSERT_OK_AND_ASSIGN(datum_result, DictionaryEncode(values, options)); + result = datum_result.make_array(); + AssertArraysEqual(*expected, *result); +} + TEST_F(TestHashKernel, ChunkedArrayZeroChunk) { // ARROW-6857 auto chunked_array = std::make_shared(ArrayVector{}, utf8()); diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 56339430ee9..5ddb270451a 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -95,6 +95,8 @@ namespace { std::string PrintDatum(const Datum& datum) { if (datum.is_scalar()) { + if (!datum.scalar()->is_valid) return "null"; + switch (datum.type()->id()) { case Type::STRING: case Type::LARGE_STRING: @@ -110,6 +112,7 @@ std::string PrintDatum(const Datum& datum) { default: break; } + return datum.scalar()->ToString(); } return datum.ToString(); @@ -698,16 +701,25 @@ 
Status ExtractKnownFieldValuesImpl( return !(ref && lit); } + if (call->function_name == "is_null") { + auto ref = call->arguments[0].field_ref(); + return !ref; + } + return true; }); for (auto it = unconsumed_end; it != conjunction_members->end(); ++it) { auto call = CallNotNull(*it); - auto ref = call->arguments[0].field_ref(); - auto lit = call->arguments[1].literal(); - - known_values->emplace(*ref, *lit); + if (call->function_name == "equal") { + auto ref = call->arguments[0].field_ref(); + auto lit = call->arguments[1].literal(); + known_values->emplace(*ref, *lit); + } else if (call->function_name == "is_null") { + auto ref = call->arguments[0].field_ref(); + known_values->emplace(*ref, Datum(std::make_shared())); + } } conjunction_members->erase(unconsumed_end, conjunction_members->end()); @@ -756,7 +768,7 @@ Result ReplaceFieldsWithKnownValues( DictionaryScalar::Make(std::move(index), std::move(dictionary))); } } - ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(it->second, expr.type())); + ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, expr.type())); return literal(std::move(lit)); } } @@ -1222,6 +1234,10 @@ Expression greater_equal(Expression lhs, Expression rhs) { return call("greater_equal", {std::move(lhs), std::move(rhs)}); } +Expression is_null(Expression lhs) { return call("is_null", {std::move(lhs)}); } + +Expression is_valid(Expression lhs) { return call("is_valid", {std::move(lhs)}); } + Expression and_(Expression lhs, Expression rhs) { return call("and_kleene", {std::move(lhs), std::move(rhs)}); } diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 13c714b2d72..8bdcb4a0ffa 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -236,6 +236,10 @@ ARROW_DS_EXPORT Expression greater(Expression lhs, Expression rhs); ARROW_DS_EXPORT Expression greater_equal(Expression lhs, Expression rhs); +ARROW_DS_EXPORT Expression is_null(Expression lhs); + +ARROW_DS_EXPORT Expression is_valid(Expression lhs); + ARROW_DS_EXPORT Expression and_(Expression lhs, Expression rhs); ARROW_DS_EXPORT Expression and_(const std::vector&); ARROW_DS_EXPORT Expression or_(Expression lhs, Expression rhs); diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 2f0110255ec..c837c5be893 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -240,6 +240,10 @@ TEST(Expression, Equality) { call("cast", {field_ref("a")}, compute::CastOptions::Unsafe(int32()))); } +Expression null_literal(const std::shared_ptr& type) { + return Expression(MakeNullScalar(type)); +} + TEST(Expression, Hash) { std::unordered_set set; @@ -250,6 +254,9 @@ TEST(Expression, Hash) { EXPECT_FALSE(set.emplace(literal(1)).second) << "already inserted"; EXPECT_TRUE(set.emplace(literal(3)).second); + EXPECT_TRUE(set.emplace(null_literal(int32())).second); + EXPECT_FALSE(set.emplace(null_literal(int32())).second) << "already inserted"; + EXPECT_TRUE(set.emplace(null_literal(float32())).second); // NB: no validation on construction; we couldn't execute // add with zero arguments EXPECT_TRUE(set.emplace(call("add", {})).second); @@ -258,7 +265,7 @@ TEST(Expression, Hash) { // NB: unbound expressions don't check for availability in any registry EXPECT_TRUE(set.emplace(call("widgetify", {})).second); - EXPECT_EQ(set.size(), 6); + EXPECT_EQ(set.size(), 8); } TEST(Expression, IsScalarExpression) { @@ -603,6 +610,8 @@ TEST(Expression, FoldConstants) { // call against literals (3 + 2 == 
5) ExpectFoldsTo(call("add", {literal(3), literal(2)}), literal(5)); + ExpectFoldsTo(call("equal", {literal(3), literal(3)}), literal(true)); + // call against literal and field_ref ExpectFoldsTo(call("add", {literal(3), field_ref("i32")}), call("add", {literal(3), field_ref("i32")})); @@ -722,7 +731,7 @@ TEST(Expression, ExtractKnownFieldValues) { TEST(Expression, ReplaceFieldsWithKnownValues) { auto ExpectReplacesTo = [](Expression expr, - std::unordered_map known_values, + const std::unordered_map& known_values, Expression unbound_expected) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); @@ -765,6 +774,19 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { }), literal(2), })); + + std::unordered_map i32_valid_str_null{ + {"i32", Datum(3)}, {"str", MakeNullScalar(utf8())}}; + + ExpectReplacesTo(is_null(field_ref("i32")), i32_valid_str_null, is_null(literal(3))); + + ExpectReplacesTo(is_valid(field_ref("i32")), i32_valid_str_null, is_valid(literal(3))); + + ExpectReplacesTo(is_null(field_ref("str")), i32_valid_str_null, + is_null(null_literal(utf8()))); + + ExpectReplacesTo(is_valid(field_ref("str")), i32_valid_str_null, + is_valid(null_literal(utf8()))); } struct { @@ -1013,6 +1035,22 @@ TEST(Expression, SimplifyWithGuarantee) { Simplify{greater(field_ref("dict_i32"), literal(int64_t(1)))} .WithGuarantee(equal(field_ref("dict_i32"), literal(0))) .Expect(false); + + Simplify{equal(field_ref("i32"), literal(7))} + .WithGuarantee(equal(field_ref("i32"), literal(7))) + .Expect(literal(true)); + + Simplify{equal(field_ref("i32"), literal(7))} + .WithGuarantee(not_(equal(field_ref("i32"), literal(7)))) + .Expect(equal(field_ref("i32"), literal(7))); + + Simplify{is_null(field_ref("i32"))} + .WithGuarantee(is_null(field_ref("i32"))) + .Expect(literal(true)); + + Simplify{is_valid(field_ref("i32"))} + .WithGuarantee(is_valid(field_ref("i32"))) + .Expect(is_valid(field_ref("i32"))); } TEST(Expression, SimplifyThenExecute) { diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index d6a3723d055..522dbbeb5d2 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -92,7 +92,11 @@ inline Expression ConjunctionFromGroupingRow(Scalar* row) { std::vector equality_expressions(values->size()); for (size_t i = 0; i < values->size(); ++i) { const std::string& name = row->type->field(static_cast(i))->name(); - equality_expressions[i] = equal(field_ref(name), literal(std::move(values->at(i)))); + if (values->at(i)->is_valid) { + equality_expressions[i] = equal(field_ref(name), literal(std::move(values->at(i)))); + } else { + equality_expressions[i] = is_null(field_ref(name)); + } } return and_(std::move(equality_expressions)); } @@ -147,7 +151,9 @@ Result KeyValuePartitioning::ConvertKey(const Key& key) const { std::shared_ptr converted; - if (field->type()->id() == Type::DICTIONARY) { + if (!key.value.has_value()) { + return is_null(field_ref(field->name())); + } else if (field->type()->id() == Type::DICTIONARY) { if (dictionaries_.empty() || dictionaries_[field_index] == nullptr) { return Status::Invalid("No dictionary provided for dictionary field ", field->ToString()); @@ -164,16 +170,16 @@ Result KeyValuePartitioning::ConvertKey(const Key& key) const { } // look up the partition value in the dictionary - ARROW_ASSIGN_OR_RAISE(converted, Scalar::Parse(value.dictionary->type(), key.value)); + ARROW_ASSIGN_OR_RAISE(converted, 
Scalar::Parse(value.dictionary->type(), *key.value)); ARROW_ASSIGN_OR_RAISE(auto index, compute::IndexIn(converted, value.dictionary)); value.index = index.scalar(); if (!value.index->is_valid) { return Status::Invalid("Dictionary supplied for field ", field->ToString(), - " does not contain '", key.value, "'"); + " does not contain '", *key.value, "'"); } converted = std::make_shared(std::move(value), field->type()); } else { - ARROW_ASSIGN_OR_RAISE(converted, Scalar::Parse(field->type(), key.value)); + ARROW_ASSIGN_OR_RAISE(converted, Scalar::Parse(field->type(), *key.value)); } return equal(field_ref(field->name()), literal(std::move(converted))); @@ -207,8 +213,18 @@ Result KeyValuePartitioning::Format(const Expression& expr) const { const auto& field = schema_->field(match[0]); if (!value->type->Equals(field->type())) { - return Status::TypeError("scalar ", value->ToString(), " (of type ", *value->type, - ") is invalid for ", field->ToString()); + if (value->is_valid) { + auto maybe_converted = compute::Cast(value, field->type()); + if (!maybe_converted.ok()) { + return Status::TypeError("Error converting scalar ", value->ToString(), + " (of type ", *value->type, + ") to a partition key for ", field->ToString(), ": ", + maybe_converted.status().message()); + } + value = maybe_converted->scalar(); + } else { + value = MakeNullScalar(field->type()); + } } if (value->type->id() == Type::DICTIONARY) { @@ -252,7 +268,7 @@ Result DirectoryPartitioning::FormatValues( std::vector segments(static_cast(schema_->num_fields())); for (int i = 0; i < schema_->num_fields(); ++i) { - if (values[i] != nullptr) { + if (values[i] != nullptr && values[i]->is_valid) { segments[i] = values[i]->ToString(); continue; } @@ -287,8 +303,13 @@ class KeyValuePartitioningFactory : public PartitioningFactory { return it_inserted.first->second; } - Status InsertRepr(const std::string& name, util::string_view repr) { - return InsertRepr(GetOrInsertField(name), repr); + Status InsertRepr(const std::string& name, util::optional repr) { + auto field_index = GetOrInsertField(name); + if (repr.has_value()) { + return InsertRepr(field_index, *repr); + } else { + return Status::OK(); + } } Status InsertRepr(int index, util::string_view repr) { @@ -309,7 +330,7 @@ class KeyValuePartitioningFactory : public PartitioningFactory { RETURN_NOT_OK(repr_memos_[index]->GetArrayData(0, &reprs)); if (reprs->length == 0) { - return Status::Invalid("No segments were available for field '", name, + return Status::Invalid("No non-null segments were available for field '", name, "'; couldn't infer type"); } @@ -410,13 +431,19 @@ std::shared_ptr DirectoryPartitioning::MakeFactory( } util::optional HivePartitioning::ParseKey( - const std::string& segment) { + const std::string& segment, const std::string& null_fallback) { auto name_end = string_view(segment).find_first_of('='); + // Not round-trippable if (name_end == string_view::npos) { return util::nullopt; } - return Key{segment.substr(0, name_end), segment.substr(name_end + 1)}; + auto name = segment.substr(0, name_end); + auto value = segment.substr(name_end + 1); + if (value == null_fallback) { + return Key{name, util::nullopt}; + } + return Key{name, value}; } std::vector HivePartitioning::ParseKeys( @@ -424,7 +451,7 @@ std::vector HivePartitioning::ParseKeys( std::vector keys; for (const auto& segment : fs::internal::SplitAbstractPath(path)) { - if (auto key = ParseKey(segment)) { + if (auto key = ParseKey(segment, null_fallback_)) { keys.push_back(std::move(*key)); } } @@ -439,11 
+466,11 @@ Result HivePartitioning::FormatValues(const ScalarVector& values) c const std::string& name = schema_->field(i)->name(); if (values[i] == nullptr) { - if (!NextValid(values, i)) break; - + segments[i] = ""; + } else if (!values[i]->is_valid) { // If no key is available just provide a placeholder segment to maintain the // field_index <-> path nesting relation - segments[i] = name; + segments[i] = name + "=" + null_fallback_; } else { segments[i] = name + "=" + values[i]->ToString(); } @@ -454,8 +481,8 @@ Result HivePartitioning::FormatValues(const ScalarVector& values) c class HivePartitioningFactory : public KeyValuePartitioningFactory { public: - explicit HivePartitioningFactory(PartitioningFactoryOptions options) - : KeyValuePartitioningFactory(options) {} + explicit HivePartitioningFactory(HivePartitioningFactoryOptions options) + : KeyValuePartitioningFactory(options), null_fallback_(options.null_fallback) {} std::string type_name() const override { return "hive"; } @@ -463,7 +490,7 @@ class HivePartitioningFactory : public KeyValuePartitioningFactory { const std::vector& paths) override { for (auto path : paths) { for (auto&& segment : fs::internal::SplitAbstractPath(path)) { - if (auto key = HivePartitioning::ParseKey(segment)) { + if (auto key = HivePartitioning::ParseKey(segment, null_fallback_)) { RETURN_NOT_OK(InsertRepr(key->name, key->value)); } } @@ -486,16 +513,18 @@ class HivePartitioningFactory : public KeyValuePartitioningFactory { // drop fields which aren't in field_names_ auto out_schema = SchemaFromColumnNames(schema, field_names_); - return std::make_shared(std::move(out_schema), dictionaries_); + return std::make_shared(std::move(out_schema), dictionaries_, + null_fallback_); } } private: + const std::string null_fallback_; std::vector field_names_; }; std::shared_ptr HivePartitioning::MakeFactory( - PartitioningFactoryOptions options) { + HivePartitioningFactoryOptions options) { return std::shared_ptr(new HivePartitioningFactory(options)); } @@ -578,10 +607,6 @@ class StructDictionary { Encoded out{nullptr, std::make_shared()}; for (const auto& column : columns) { - if (column->null_count() != 0) { - return Status::NotImplemented("Grouping on a field with nulls"); - } - RETURN_NOT_OK(out.dictionary->AddOne(column, &out.indices)); } @@ -625,8 +650,27 @@ class StructDictionary { private: Status AddOne(Datum column, std::shared_ptr* fused_indices) { + if (column.type()->id() == Type::DICTIONARY) { + if (column.null_count() != 0) { + // TODO(ARROW-11732) Optimize this by allowign DictionaryEncode to transfer a + // null-masked dictionary to a null-encoded dictionary. At the moment we decode + // and then encode causing one extra copy, and a potentially expansive decoding + // copy at that. 
+ ARROW_ASSIGN_OR_RAISE( + auto decoded_dictionary, + compute::Cast( + column, + std::static_pointer_cast(column.type())->value_type(), + compute::CastOptions())); + column = decoded_dictionary; + } + } if (column.type()->id() != Type::DICTIONARY) { - ARROW_ASSIGN_OR_RAISE(column, compute::DictionaryEncode(std::move(column))); + compute::DictionaryEncodeOptions options; + options.null_encoding_behavior = + compute::DictionaryEncodeOptions::NullEncodingBehavior::ENCODE; + ARROW_ASSIGN_OR_RAISE(column, + compute::DictionaryEncode(std::move(column), options)); } auto dict_column = column.array_as(); diff --git a/cpp/src/arrow/dataset/partition.h b/cpp/src/arrow/dataset/partition.h index 944434e64f7..42e1b4c4097 100644 --- a/cpp/src/arrow/dataset/partition.h +++ b/cpp/src/arrow/dataset/partition.h @@ -92,6 +92,11 @@ struct PartitioningFactoryOptions { bool infer_dictionary = false; }; +struct HivePartitioningFactoryOptions : PartitioningFactoryOptions { + /// The hive partitioning scheme maps null to a hard coded fallback string. + std::string null_fallback; +}; + /// \brief PartitioningFactory provides creation of a partitioning when the /// specific schema must be inferred from available paths (no explicit schema is known). class ARROW_DS_EXPORT PartitioningFactory { @@ -119,7 +124,8 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { /// An unconverted equality expression consisting of a field name and the representation /// of a scalar value struct Key { - std::string name, value; + std::string name; + util::optional value; }; static Status SetDefaultValuesFromKeys(const Expression& expr, @@ -175,6 +181,8 @@ class ARROW_DS_EXPORT DirectoryPartitioning : public KeyValuePartitioning { Result FormatValues(const ScalarVector& values) const override; }; +static constexpr char kDefaultHiveNullFallback[] = "__HIVE_DEFAULT_PARTITION__"; + /// \brief Multi-level, directory based partitioning /// originating from Apache Hive with all data files stored in the /// leaf directories. Data is partitioned by static values of a @@ -188,17 +196,22 @@ class ARROW_DS_EXPORT HivePartitioning : public KeyValuePartitioning { public: // If a field in schema is of dictionary type, the corresponding element of dictionaries // must be contain the dictionary of values for that field. 
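+  // A formatted value equal to null_fallback is parsed back as a null partition
+  // value, and null partition values are written out using null_fallback.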
- explicit HivePartitioning(std::shared_ptr schema, ArrayVector dictionaries = {}) - : KeyValuePartitioning(std::move(schema), std::move(dictionaries)) {} + explicit HivePartitioning(std::shared_ptr schema, ArrayVector dictionaries = {}, + std::string null_fallback = kDefaultHiveNullFallback) + : KeyValuePartitioning(std::move(schema), std::move(dictionaries)), + null_fallback_(null_fallback) {} std::string type_name() const override { return "hive"; } + std::string null_fallback() const { return null_fallback_; } - static util::optional ParseKey(const std::string& segment); + static util::optional ParseKey(const std::string& segment, + const std::string& null_fallback); static std::shared_ptr MakeFactory( - PartitioningFactoryOptions = {}); + HivePartitioningFactoryOptions = {}); private: + const std::string null_fallback_; std::vector ParseKeys(const std::string& path) const override; Result FormatValues(const ScalarVector& values) const override; diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 286848d9ae9..75e60f994f0 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -27,6 +27,7 @@ #include #include "arrow/compute/api_scalar.h" +#include "arrow/compute/api_vector.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/dataset/test_util.h" #include "arrow/filesystem/path_util.h" @@ -77,6 +78,39 @@ class TestPartitioning : public ::testing::Test { ASSERT_OK_AND_ASSIGN(partitioning_, factory_->Finish(actual)); } + void AssertPartition(const std::shared_ptr partitioning, + const std::shared_ptr full_batch, + const RecordBatchVector& expected_batches, + const std::vector& expected_expressions) { + ASSERT_OK_AND_ASSIGN(auto partition_results, partitioning->Partition(full_batch)); + std::shared_ptr rest = full_batch; + ASSERT_EQ(partition_results.batches.size(), expected_batches.size()); + auto max_index = std::min(partition_results.batches.size(), expected_batches.size()); + for (std::size_t partition_index = 0; partition_index < max_index; + partition_index++) { + std::shared_ptr actual_batch = + partition_results.batches[partition_index]; + AssertBatchesEqual(*expected_batches[partition_index], *actual_batch); + Expression actual_expression = partition_results.expressions[partition_index]; + ASSERT_EQ(expected_expressions[partition_index], actual_expression); + } + } + + void AssertPartition(const std::shared_ptr partitioning, + const std::shared_ptr schema, + const std::string& record_batch_json, + const std::shared_ptr partitioned_schema, + const std::vector& expected_record_batch_strs, + const std::vector& expected_expressions) { + auto record_batch = RecordBatchFromJSON(schema, record_batch_json); + RecordBatchVector expected_batches; + for (const auto& expected_record_batch_str : expected_record_batch_strs) { + expected_batches.push_back( + RecordBatchFromJSON(partitioned_schema, expected_record_batch_str)); + } + AssertPartition(partitioning, record_batch, expected_batches, expected_expressions); + } + void AssertInspectError(const std::vector& paths) { ASSERT_RAISES(Invalid, factory_->Inspect(paths)); } @@ -103,6 +137,30 @@ class TestPartitioning : public ::testing::Test { std::shared_ptr written_schema_; }; +TEST_F(TestPartitioning, Partition) { + auto partition_schema = schema({field("a", int32()), field("b", utf8())}); + auto schema_ = schema({field("a", int32()), field("b", utf8()), field("c", uint32())}); + auto remaining_schema = schema({field("c", uint32())}); + auto 
partitioning = std::make_shared(partition_schema); + std::string json = R"([{"a": 3, "b": "x", "c": 0}, + {"a": 3, "b": "x", "c": 1}, + {"a": 1, "b": null, "c": 2}, + {"a": null, "b": null, "c": 3}, + {"a": null, "b": "z", "c": 4}, + {"a": null, "b": null, "c": 5} + ])"; + std::vector expected_batches = {R"([{"c": 0}, {"c": 1}])", R"([{"c": 2}])", + R"([{"c": 3}, {"c": 5}])", + R"([{"c": 4}])"}; + std::vector expected_expressions = { + and_(equal(field_ref("a"), literal(3)), equal(field_ref("b"), literal("x"))), + and_(equal(field_ref("a"), literal(1)), is_null(field_ref("b"))), + and_(is_null(field_ref("a")), is_null(field_ref("b"))), + and_(is_null(field_ref("a")), equal(field_ref("b"), literal("z")))}; + AssertPartition(partitioning, schema_, json, remaining_schema, expected_batches, + expected_expressions); +} + TEST_F(TestPartitioning, DirectoryPartitioning) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", utf8())})); @@ -136,6 +194,10 @@ TEST_F(TestPartitioning, DirectoryPartitioningFormat) { equal(field_ref("alpha"), literal(0))), "0/hello"); AssertFormat(equal(field_ref("alpha"), literal(0)), "0"); + AssertFormat(and_(equal(field_ref("alpha"), literal(0)), is_null(field_ref("beta"))), + "0"); + AssertFormatError( + and_(is_null(field_ref("alpha")), equal(field_ref("beta"), literal("hello")))); AssertFormatError(equal(field_ref("beta"), literal("hello"))); AssertFormat(literal(true), ""); @@ -209,6 +271,8 @@ TEST_F(TestPartitioning, DictionaryInference) { // successful dictionary inference AssertInspect({"/a/0"}, {DictStr("alpha"), DictInt("beta")}); AssertInspect({"/a/0", "/a/1"}, {DictStr("alpha"), DictInt("beta")}); + AssertInspect({"/a/0", "/a"}, {DictStr("alpha"), DictInt("beta")}); + AssertInspect({"/0/a", "/1"}, {DictInt("alpha"), DictStr("beta")}); AssertInspect({"/a/0", "/b/0", "/a/1", "/b/1"}, {DictStr("alpha"), DictInt("beta")}); AssertInspect({"/a/-", "/b/-", "/a/_", "/b/_"}, {DictStr("alpha"), DictStr("beta")}); } @@ -246,13 +310,15 @@ TEST_F(TestPartitioning, DiscoverSchemaSegfault) { TEST_F(TestPartitioning, HivePartitioning) { partitioning_ = std::make_shared( - schema({field("alpha", int32()), field("beta", float32())})); + schema({field("alpha", int32()), field("beta", float32())}), ArrayVector(), "xyz"); AssertParse("/alpha=0/beta=3.25", and_(equal(field_ref("alpha"), literal(0)), equal(field_ref("beta"), literal(3.25f)))); AssertParse("/beta=3.25/alpha=0", and_(equal(field_ref("beta"), literal(3.25f)), equal(field_ref("alpha"), literal(0)))); AssertParse("/alpha=0", equal(field_ref("alpha"), literal(0))); + AssertParse("/alpha=xyz/beta=3.25", and_(is_null(field_ref("alpha")), + equal(field_ref("beta"), literal(3.25f)))); AssertParse("/beta=3.25", equal(field_ref("beta"), literal(3.25f))); AssertParse("", literal(true)); @@ -271,7 +337,7 @@ TEST_F(TestPartitioning, HivePartitioning) { TEST_F(TestPartitioning, HivePartitioningFormat) { partitioning_ = std::make_shared( - schema({field("alpha", int32()), field("beta", float32())})); + schema({field("alpha", int32()), field("beta", float32())}), ArrayVector(), "xyz"); written_schema_ = partitioning_->schema(); @@ -282,9 +348,16 @@ TEST_F(TestPartitioning, HivePartitioningFormat) { equal(field_ref("alpha"), literal(0))), "alpha=0/beta=3.25"); AssertFormat(equal(field_ref("alpha"), literal(0)), "alpha=0"); - AssertFormat(equal(field_ref("beta"), literal(3.25f)), "alpha/beta=3.25"); + AssertFormat(and_(equal(field_ref("alpha"), literal(0)), is_null(field_ref("beta"))), + 
"alpha=0/beta=xyz"); + AssertFormat( + and_(is_null(field_ref("alpha")), equal(field_ref("beta"), literal(3.25f))), + "alpha=xyz/beta=3.25"); AssertFormat(literal(true), ""); + AssertFormat(and_(is_null(field_ref("alpha")), is_null(field_ref("beta"))), + "alpha=xyz/beta=xyz"); + ASSERT_OK_AND_ASSIGN(written_schema_, written_schema_->AddField(0, field("gamma", utf8()))); AssertFormat(and_({equal(field_ref("gamma"), literal("yo")), @@ -300,7 +373,9 @@ TEST_F(TestPartitioning, HivePartitioningFormat) { } TEST_F(TestPartitioning, DiscoverHiveSchema) { - factory_ = HivePartitioning::MakeFactory(); + auto options = HivePartitioningFactoryOptions(); + options.null_fallback = "xyz"; + factory_ = HivePartitioning::MakeFactory(options); // type is int32 if possible AssertInspect({"/alpha=0/beta=1"}, {Int("alpha"), Int("beta")}); @@ -313,6 +388,12 @@ TEST_F(TestPartitioning, DiscoverHiveSchema) { // (...so ensure your partitions are ordered the same for all paths) AssertInspect({"/alpha=0/beta=1", "/beta=2/alpha=3"}, {Int("alpha"), Int("beta")}); + // Null fallback strings shouldn't interfere with type inference + AssertInspect({"/alpha=xyz/beta=x", "/alpha=7/beta=xyz"}, {Int("alpha"), Str("beta")}); + + // Cannot infer if the only values are null + AssertInspectError({"/alpha=xyz"}); + // If there are too many digits fall back to string AssertInspect({"/alpha=3760212050"}, {Str("alpha")}); @@ -322,8 +403,9 @@ TEST_F(TestPartitioning, DiscoverHiveSchema) { } TEST_F(TestPartitioning, HiveDictionaryInference) { - PartitioningFactoryOptions options; + HivePartitioningFactoryOptions options; options.infer_dictionary = true; + options.null_fallback = "xyz"; factory_ = HivePartitioning::MakeFactory(options); // type is still int32 if possible @@ -335,6 +417,8 @@ TEST_F(TestPartitioning, HiveDictionaryInference) { // successful dictionary inference AssertInspect({"/alpha=a/beta=0"}, {DictStr("alpha"), DictInt("beta")}); AssertInspect({"/alpha=a/beta=0", "/alpha=a/1"}, {DictStr("alpha"), DictInt("beta")}); + AssertInspect({"/alpha=a/beta=0", "/alpha=xyz/beta=xyz"}, + {DictStr("alpha"), DictInt("beta")}); AssertInspect( {"/alpha=a/beta=0", "/alpha=b/beta=0", "/alpha=a/beta=1", "/alpha=b/beta=1"}, {DictStr("alpha"), DictInt("beta")}); @@ -343,8 +427,19 @@ TEST_F(TestPartitioning, HiveDictionaryInference) { {DictStr("alpha"), DictStr("beta")}); } +TEST_F(TestPartitioning, HiveNullFallbackPassedOn) { + HivePartitioningFactoryOptions options; + options.null_fallback = "xyz"; + factory_ = HivePartitioning::MakeFactory(options); + + EXPECT_OK_AND_ASSIGN(auto schema, factory_->Inspect({"/alpha=a/beta=0"})); + EXPECT_OK_AND_ASSIGN(auto partitioning, factory_->Finish(schema)); + ASSERT_EQ("xyz", + std::static_pointer_cast(partitioning)->null_fallback()); +} + TEST_F(TestPartitioning, HiveDictionaryHasUniqueValues) { - PartitioningFactoryOptions options; + HivePartitioningFactoryOptions options; options.infer_dictionary = true; factory_ = HivePartitioning::MakeFactory(options); @@ -369,6 +464,55 @@ TEST_F(TestPartitioning, HiveDictionaryHasUniqueValues) { AssertParseError("/alpha=yosemite"); // not in inspected dictionary } +TEST_F(TestPartitioning, SetDefaultValuesConcrete) { + auto small_schm = schema({field("c", int32())}); + auto schm = schema({field("a", int32()), field("b", utf8())}); + auto full_schm = schema({field("a", int32()), field("b", utf8()), field("c", int32())}); + RecordBatchProjector record_batch_projector(full_schm); + HivePartitioning part(schm); + ARROW_EXPECT_OK(part.SetDefaultValuesFromKeys( + 
and_(equal(field_ref("a"), literal(10)), is_valid(field_ref("b"))), + &record_batch_projector)); + + auto in_rb = RecordBatchFromJSON(small_schm, R"([{"c": 0}, + {"c": 1}, + {"c": 2}, + {"c": 3} + ])"); + + EXPECT_OK_AND_ASSIGN(auto out_rb, record_batch_projector.Project(*in_rb)); + auto expected_rb = RecordBatchFromJSON(full_schm, R"([{"a": 10, "b": null, "c": 0}, + {"a": 10, "b": null, "c": 1}, + {"a": 10, "b": null, "c": 2}, + {"a": 10, "b": null, "c": 3} + ])"); + AssertBatchesEqual(*expected_rb, *out_rb); +} + +TEST_F(TestPartitioning, SetDefaultValuesNull) { + auto small_schm = schema({field("c", int32())}); + auto schm = schema({field("a", int32()), field("b", utf8())}); + auto full_schm = schema({field("a", int32()), field("b", utf8()), field("c", int32())}); + RecordBatchProjector record_batch_projector(full_schm); + HivePartitioning part(schm); + ARROW_EXPECT_OK(part.SetDefaultValuesFromKeys( + and_(is_null(field_ref("a")), is_null(field_ref("b"))), &record_batch_projector)); + + auto in_rb = RecordBatchFromJSON(small_schm, R"([{"c": 0}, + {"c": 1}, + {"c": 2}, + {"c": 3} + ])"); + + EXPECT_OK_AND_ASSIGN(auto out_rb, record_batch_projector.Project(*in_rb)); + auto expected_rb = RecordBatchFromJSON(full_schm, R"([{"a": null, "b": null, "c": 0}, + {"a": null, "b": null, "c": 1}, + {"a": null, "b": null, "c": 2}, + {"a": null, "b": null, "c": 3} + ])"); + AssertBatchesEqual(*expected_rb, *out_rb); +} + TEST_F(TestPartitioning, EtlThenHive) { FieldVector etl_fields{field("year", int16()), field("month", int8()), field("day", int8()), field("hour", int8())}; @@ -467,13 +611,13 @@ class RangePartitioning : public Partitioning { std::vector ranges; for (auto segment : fs::internal::SplitAbstractPath(path)) { - auto key = HivePartitioning::ParseKey(segment); + auto key = HivePartitioning::ParseKey(segment, ""); if (!key) { return Status::Invalid("can't parse '", segment, "' as a range"); } std::smatch matches; - RETURN_NOT_OK(DoRegex(key->value, &matches)); + RETURN_NOT_OK(DoRegex(*key->value, &matches)); auto& min_cmp = matches[1] == "[" ? 
greater_equal : greater; std::string min_repr = matches[2]; @@ -600,20 +744,45 @@ TEST(GroupTest, Basics) { } TEST(GroupTest, WithNulls) { - auto has_nulls = checked_pointer_cast( - ArrayFromJSON(struct_({field("a", utf8()), field("b", int32())}), R"([ - {"a": "ex", "b": 0}, - {"a": null, "b": 0}, - {"a": "why", "b": 0}, - {"a": "ex", "b": 1}, - {"a": "why", "b": 0}, - {"a": "ex", "b": 1}, - {"a": "ex", "b": 0}, - {"a": "why", "b": null} - ])")); - ASSERT_RAISES(NotImplemented, MakeGroupings(*has_nulls)); + AssertGrouping({field("a", utf8()), field("b", int32())}, + R"([ + {"a": "ex", "b": 0, "id": 0}, + {"a": null, "b": 0, "id": 1}, + {"a": null, "b": 0, "id": 2}, + {"a": "ex", "b": 1, "id": 3}, + {"a": null, "b": null, "id": 4}, + {"a": "ex", "b": 1, "id": 5}, + {"a": "ex", "b": 0, "id": 6}, + {"a": "why", "b": null, "id": 7} + ])", + R"([ + {"a": "ex", "b": 0, "ids": [0, 6]}, + {"a": null, "b": 0, "ids": [1, 2]}, + {"a": "ex", "b": 1, "ids": [3, 5]}, + {"a": null, "b": null, "ids": [4]}, + {"a": "why", "b": null, "ids": [7]} + ])"); - has_nulls = checked_pointer_cast( + AssertGrouping({field("a", dictionary(int32(), utf8())), field("b", int32())}, + R"([ + {"a": "ex", "b": 0, "id": 0}, + {"a": null, "b": 0, "id": 1}, + {"a": null, "b": 0, "id": 2}, + {"a": "ex", "b": 1, "id": 3}, + {"a": null, "b": null, "id": 4}, + {"a": "ex", "b": 1, "id": 5}, + {"a": "ex", "b": 0, "id": 6}, + {"a": "why", "b": null, "id": 7} + ])", + R"([ + {"a": "ex", "b": 0, "ids": [0, 6]}, + {"a": null, "b": 0, "ids": [1, 2]}, + {"a": "ex", "b": 1, "ids": [3, 5]}, + {"a": null, "b": null, "ids": [4]}, + {"a": "why", "b": null, "ids": [7]} + ])"); + + auto has_nulls = checked_pointer_cast( ArrayFromJSON(struct_({field("a", utf8()), field("b", int32())}), R"([ {"a": "ex", "b": 0}, null, diff --git a/cpp/src/arrow/dataset/projector.cc b/cpp/src/arrow/dataset/projector.cc index 2ba679ce6e7..ba0eb2ddff5 100644 --- a/cpp/src/arrow/dataset/projector.cc +++ b/cpp/src/arrow/dataset/projector.cc @@ -23,6 +23,7 @@ #include #include "arrow/array.h" +#include "arrow/compute/cast.h" #include "arrow/dataset/type_fwd.h" #include "arrow/record_batch.h" #include "arrow/result.h" @@ -88,9 +89,18 @@ Status RecordBatchProjector::SetDefaultValue(FieldRef ref, auto field_type = to_->field(index)->type(); if (!field_type->Equals(scalar->type)) { - return Status::TypeError("field ", to_->field(index)->ToString(), - " cannot be materialized from scalar of type ", - *scalar->type); + if (scalar->is_valid) { + auto maybe_converted = compute::Cast(scalar, field_type); + if (!maybe_converted.ok()) { + return Status::TypeError("Field ", to_->field(index)->ToString(), + " cannot be materialized from scalar of type ", + *scalar->type, + ". 
Cast error: ", maybe_converted.status().message()); + } + scalar = maybe_converted->scalar(); + } else { + scalar = MakeNullScalar(field_type); + } } scalars_[index] = std::move(scalar); diff --git a/cpp/src/arrow/python/arrow_to_pandas.cc b/cpp/src/arrow/python/arrow_to_pandas.cc index 09245285030..1c47f9742de 100644 --- a/cpp/src/arrow/python/arrow_to_pandas.cc +++ b/cpp/src/arrow/python/arrow_to_pandas.cc @@ -2183,7 +2183,9 @@ Status ConvertCategoricals(const PandasOptions& options, ChunkedArrayVector* arr "only zero-copy conversions allowed"); } compute::ExecContext ctx(options.pool); - ARROW_ASSIGN_OR_RAISE(Datum out, DictionaryEncode((*arrays)[i], &ctx)); + ARROW_ASSIGN_OR_RAISE( + Datum out, DictionaryEncode((*arrays)[i], + compute::DictionaryEncodeOptions::Defaults(), &ctx)); (*arrays)[i] = out.chunked_array(); (*fields)[i] = (*fields)[i]->WithType((*arrays)[i]->type()); return Status::OK(); @@ -2232,7 +2234,9 @@ Status ConvertChunkedArrayToPandas(const PandasOptions& options, "only zero-copy conversions allowed"); } compute::ExecContext ctx(options.pool); - ARROW_ASSIGN_OR_RAISE(Datum out, DictionaryEncode(arr, &ctx)); + ARROW_ASSIGN_OR_RAISE( + Datum out, + DictionaryEncode(arr, compute::DictionaryEncodeOptions::Defaults(), &ctx)); arr = out.chunked_array(); } diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index e5a19288b87..3cb152aa381 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -649,6 +649,32 @@ class FilterOptions(_FilterOptions): self._set_options(null_selection_behavior) +cdef class _DictionaryEncodeOptions(FunctionOptions): + cdef: + unique_ptr[CDictionaryEncodeOptions] dictionary_encode_options + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.dictionary_encode_options.get() + + def _set_options(self, null_encoding_behavior): + if null_encoding_behavior == 'encode': + self.dictionary_encode_options.reset( + new CDictionaryEncodeOptions( + CDictionaryEncodeNullEncodingBehavior_ENCODE)) + elif null_encoding_behavior == 'mask': + self.dictionary_encode_options.reset( + new CDictionaryEncodeOptions( + CDictionaryEncodeNullEncodingBehavior_MASK)) + else: + raise ValueError('"{}" is not a valid null_encoding_behavior' + .format(null_encoding_behavior)) + + +class DictionaryEncodeOptions(_DictionaryEncodeOptions): + def __init__(self, null_encoding_behavior='mask'): + self._set_options(null_encoding_behavior) + + cdef class _TakeOptions(FunctionOptions): cdef: unique_ptr[CTakeOptions] take_options diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index c67dbc99d77..1c4e5d302c5 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -206,6 +206,10 @@ cdef class Expression(_Weakrefable): """Checks whether the expression is not-null (valid)""" return Expression._call("is_valid", [self]) + def is_null(self): + """Checks whether the expression is null""" + return Expression._call("is_null", [self]) + def cast(self, type, bint safe=True): """Explicitly change the expression's data type""" cdef shared_ptr[CCastOptions] c_options @@ -1546,7 +1550,7 @@ cdef class DirectoryPartitioning(Partitioning): Returns ------- - DirectoryPartitioningFactory + PartitioningFactory To be used in the FileSystemFactoryOptions. """ cdef: @@ -1590,6 +1594,8 @@ cdef class HivePartitioning(Partitioning): corresponding entry of `dictionaries` must be an array containing every value which may be taken by the corresponding column or an error will be raised in parsing. 
+ null_fallback : str, default "__HIVE_DEFAULT_PARTITION__" + If any field is None then this fallback will be used as a label Returns ------- @@ -1608,13 +1614,19 @@ cdef class HivePartitioning(Partitioning): cdef: CHivePartitioning* hive_partitioning - def __init__(self, Schema schema not None, dictionaries=None): + def __init__(self, + Schema schema not None, + dictionaries=None, + null_fallback="__HIVE_DEFAULT_PARTITION__"): + cdef: shared_ptr[CHivePartitioning] c_partitioning + c_string c_null_fallback = tobytes(null_fallback) c_partitioning = make_shared[CHivePartitioning]( pyarrow_unwrap_schema(schema), - _partitioning_dictionaries(schema, dictionaries) + _partitioning_dictionaries(schema, dictionaries), + c_null_fallback ) self.init( c_partitioning) @@ -1623,7 +1635,9 @@ cdef class HivePartitioning(Partitioning): self.hive_partitioning = sp.get() @staticmethod - def discover(infer_dictionary=False, max_partition_dictionary_size=0): + def discover(infer_dictionary=False, + max_partition_dictionary_size=0, + null_fallback="__HIVE_DEFAULT_PARTITION__"): """ Discover a HivePartitioning. @@ -1639,6 +1653,10 @@ cdef class HivePartitioning(Partitioning): Synonymous with infer_dictionary for backwards compatibility with 1.0: setting this to -1 or None is equivalent to passing infer_dictionary=True. + null_fallback : str, default "__HIVE_DEFAULT_PARTITION__" + When inferring a schema for partition fields this value will be + replaced by null. The default is set to __HIVE_DEFAULT_PARTITION__ + for compatibility with Spark Returns ------- @@ -1646,7 +1664,7 @@ cdef class HivePartitioning(Partitioning): To be used in the FileSystemFactoryOptions. """ cdef: - CPartitioningFactoryOptions c_options + CHivePartitioningFactoryOptions c_options if max_partition_dictionary_size in {-1, None}: infer_dictionary = True @@ -1657,6 +1675,8 @@ cdef class HivePartitioning(Partitioning): if infer_dictionary: c_options.infer_dictionary = True + c_options.null_fallback = tobytes(null_fallback) + return PartitioningFactory.wrap( CHivePartitioning.MakeFactory(c_options)) diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index ae9e213b98d..a832b00b1eb 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -842,11 +842,12 @@ cdef class Array(_PandasConvertible): """ return _pc().call_function('unique', [self]) - def dictionary_encode(self): + def dictionary_encode(self, null_encoding='mask'): """ Compute dictionary-encoded representation of array. 
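+
+        Parameters
+        ----------
+        null_encoding : str, default 'mask'
+            How to handle nulls in the input: 'mask' leaves any null masked
+            in the indices array, 'encode' adds the null to the dictionary
+            with its own index.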
""" - return _pc().call_function('dictionary_encode', [self]) + options = _pc().DictionaryEncodeOptions(null_encoding) + return _pc().call_function('dictionary_encode', [self], options) def value_counts(self): """ diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 616b2de89ec..3d7f5ecb4c3 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -30,6 +30,7 @@ ArraySortOptions, CastOptions, CountOptions, + DictionaryEncodeOptions, FilterOptions, MatchSubstringOptions, MinMaxOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index e10ef1e3a5e..ba3c3ad7d2b 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1802,6 +1802,20 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CFilterOptions(CFilterNullSelectionBehavior null_selection) CFilterNullSelectionBehavior null_selection_behavior + enum CDictionaryEncodeNullEncodingBehavior \ + "arrow::compute::DictionaryEncodeOptions::NullEncodingBehavior": + CDictionaryEncodeNullEncodingBehavior_ENCODE \ + "arrow::compute::DictionaryEncodeOptions::ENCODE" + CDictionaryEncodeNullEncodingBehavior_MASK \ + "arrow::compute::DictionaryEncodeOptions::MASK" + + cdef cppclass CDictionaryEncodeOptions \ + "arrow::compute::DictionaryEncodeOptions"(CFunctionOptions): + CDictionaryEncodeOptions() + CDictionaryEncodeOptions( + CDictionaryEncodeNullEncodingBehavior null_encoding) + CDictionaryEncodeNullEncodingBehavior null_encoding + cdef cppclass CTakeOptions \ " arrow::compute::TakeOptions"(CFunctionOptions): CTakeOptions(c_bool boundscheck) diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index 29f9738dedc..93bc0edddc1 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -274,6 +274,11 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: "arrow::dataset::PartitioningFactoryOptions": c_bool infer_dictionary + cdef cppclass CHivePartitioningFactoryOptions \ + "arrow::dataset::HivePartitioningFactoryOptions": + c_bool infer_dictionary, + c_string null_fallback + cdef cppclass CPartitioningFactory "arrow::dataset::PartitioningFactory": pass @@ -293,7 +298,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: @staticmethod shared_ptr[CPartitioningFactory] MakeFactory( - CPartitioningFactoryOptions) + CHivePartitioningFactoryOptions) cdef cppclass CPartitioningOrFactory \ "arrow::dataset::PartitioningOrFactory": diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index aa738f9aaea..998af512c55 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -251,6 +251,9 @@ cdef api object pyarrow_wrap_scalar(const shared_ptr[CScalar]& sp_scalar): if data_type == NULL: raise ValueError('Scalar data type was NULL') + if data_type.id() == _Type_NA: + return _NULL + if data_type.id() not in _scalar_classes: raise ValueError('Scalar type not supported') diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index c6b0b4180b6..3f1fc28ee60 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -276,7 +276,7 @@ cdef class ChunkedArray(_PandasConvertible): """ return _pc().cast(self, target_type, safe=safe) - def dictionary_encode(self): + def dictionary_encode(self, null_encoding='mask'): """ Compute dictionary-encoded representation of array @@ -285,7 +285,8 @@ cdef class 
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index c6b0b4180b6..3f1fc28ee60 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -276,7 +276,7 @@ cdef class ChunkedArray(_PandasConvertible):
         """
         return _pc().cast(self, target_type, safe=safe)
 
-    def dictionary_encode(self):
+    def dictionary_encode(self, null_encoding='mask'):
         """
         Compute dictionary-encoded representation of array
 
@@ -285,7 +285,8 @@ cdef class ChunkedArray(_PandasConvertible):
         pyarrow.ChunkedArray
             Same chunking as the input, all chunks share a common dictionary.
         """
-        return _pc().call_function('dictionary_encode', [self])
+        options = _pc().DictionaryEncodeOptions(null_encoding)
+        return _pc().call_function('dictionary_encode', [self], options)
 
     def flatten(self, MemoryPool memory_pool=None):
         """
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 796f6d998e8..57179f391de 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -16,6 +16,8 @@
 # under the License.
 
 import contextlib
+import os
+import posixpath
 import pathlib
 import pickle
 import textwrap
@@ -381,11 +383,16 @@ def test_partitioning():
     with pytest.raises(pa.ArrowInvalid):
         partitioning.parse('/prefix/3/aaa')
 
+    expr = partitioning.parse('/3')
+    expected = ds.field('group') == 3
+    assert expr.equals(expected)
+
     partitioning = ds.HivePartitioning(
         pa.schema([
             pa.field('alpha', pa.int64()),
             pa.field('beta', pa.int64())
-        ])
+        ]),
+        null_fallback='xyz'
     )
     expr = partitioning.parse('/alpha=0/beta=3')
     expected = (
@@ -394,6 +401,12 @@ def test_partitioning():
     )
     assert expr.equals(expected)
 
+    expr = partitioning.parse('/alpha=xyz/beta=3')
+    expected = (
+        (ds.field('alpha').is_null() & (ds.field('beta') == ds.scalar(3)))
+    )
+    assert expr.equals(expected)
+
     for shouldfail in ['/alpha=one/beta=2', '/alpha=one', '/beta=two']:
         with pytest.raises(pa.ArrowInvalid):
             partitioning.parse(shouldfail)
@@ -412,7 +425,7 @@ def test_expression_serialization():
                  d.is_valid(), a.cast(pa.int32(), safe=False),
                  a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
                  ds.field('i64') > 5, ds.field('i64') == 5,
-                 ds.field('i64') == 7]
+                 ds.field('i64') == 7, ds.field('i64').is_null()]
     for expr in all_exprs:
         assert isinstance(expr, ds.Expression)
         restored = pickle.loads(pickle.dumps(expr))
@@ -468,6 +481,9 @@ def test_partition_keys():
     assert ds._get_partition_keys(nope) == {}
     assert ds._get_partition_keys(a & nope) == {'a': 'a'}
 
+    null = ds.field('a').is_null()
+    assert ds._get_partition_keys(null) == {'a': None}
+
 
 def test_parquet_read_options():
     opts1 = ds.ParquetReadOptions()
@@ -1239,6 +1255,57 @@ def test_partitioning_factory_dictionary(mockfs, infer_dictionary):
     assert inferred_schema.field('key').type == pa.string()
 
 
+def test_dictionary_partitioning_outer_nulls_raises(tempdir):
+    table = pa.table({'a': ['x', 'y', None], 'b': ['x', 'y', 'z']})
+    part = ds.partitioning(
+        pa.schema([pa.field('a', pa.string()), pa.field('b', pa.string())]))
+    with pytest.raises(pa.ArrowInvalid):
+        ds.write_dataset(table, tempdir, format='parquet', partitioning=part)
+
+
+def _has_subdirs(basedir):
+    elements = os.listdir(basedir)
+    return any([os.path.isdir(os.path.join(basedir, el)) for el in elements])
+
+
+def _do_list_all_dirs(basedir, path_so_far, result):
+    for f in os.listdir(basedir):
+        true_nested = os.path.join(basedir, f)
+        if os.path.isdir(true_nested):
+            norm_nested = posixpath.join(path_so_far, f)
+            if _has_subdirs(true_nested):
+                _do_list_all_dirs(true_nested, norm_nested, result)
+            else:
+                result.append(norm_nested)
+
+
+def _list_all_dirs(basedir):
+    result = []
+    _do_list_all_dirs(basedir, '', result)
+    return result
+
+
+def _check_dataset_directories(tempdir, expected_directories):
+    actual_directories = set(_list_all_dirs(tempdir))
+    assert actual_directories == set(expected_directories)
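
The directory-listing helpers above support the write-side tests that follow.
A compact sketch of the behavior they check, assuming the patched
ds.write_dataset (table contents and directory names mirror the tests below):

    import tempfile
    import pyarrow as pa
    import pyarrow.dataset as ds

    table = pa.table({'a': ['x', 'y', 'z'], 'b': ['x', 'y', None]})
    part = ds.partitioning(
        pa.schema([pa.field('a', pa.string()), pa.field('b', pa.string())]))

    with tempfile.TemporaryDirectory() as tmp:
        # A null in the innermost field truncates the path, giving the leaf
        # directories 'x/x', 'y/y' and 'z'. A null in an outer field with a
        # non-null field nested below it raises ArrowInvalid, while
        # HivePartitioning writes the fallback label (e.g. 'a=xyz') instead.
        ds.write_dataset(table, tmp, format='parquet', partitioning=part)
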
+
+
+def test_dictionary_partitioning_inner_nulls(tempdir):
+    table = pa.table({'a': ['x', 'y', 'z'], 'b': ['x', 'y', None]})
+    part = ds.partitioning(
+        pa.schema([pa.field('a', pa.string()), pa.field('b', pa.string())]))
+    ds.write_dataset(table, tempdir, format='parquet', partitioning=part)
+    _check_dataset_directories(tempdir, ['x/x', 'y/y', 'z'])
+
+
+def test_hive_partitioning_nulls(tempdir):
+    table = pa.table({'a': ['x', None, 'z'], 'b': ['x', 'y', None]})
+    part = ds.HivePartitioning(pa.schema(
+        [pa.field('a', pa.string()), pa.field('b', pa.string())]), None, 'xyz')
+    ds.write_dataset(table, tempdir, format='parquet', partitioning=part)
+    _check_dataset_directories(tempdir, ['a=x/b=x', 'a=xyz/b=y', 'a=z/b=xyz'])
+
+
 def test_partitioning_function():
     schema = pa.schema([("year", pa.int16()), ("month", pa.int8())])
     names = ["year", "month"]
@@ -1600,25 +1667,48 @@ def test_open_dataset_non_existing_file():
 
 @pytest.mark.parquet
 @pytest.mark.parametrize('partitioning', ["directory", "hive"])
+@pytest.mark.parametrize('null_fallback', ['xyz', None])
+@pytest.mark.parametrize('infer_dictionary', [False, True])
 @pytest.mark.parametrize('partition_keys', [
     (["A", "B", "C"], [1, 2, 3]),
     ([1, 2, 3], ["A", "B", "C"]),
     (["A", "B", "C"], ["D", "E", "F"]),
     ([1, 2, 3], [4, 5, 6]),
+    ([1, None, 3], ["A", "B", "C"]),
+    ([1, 2, 3], ["A", None, "C"]),
+    ([None, 2, 3], [None, 2, 3]),
 ])
-def test_open_dataset_partitioned_dictionary_type(tempdir, partitioning,
-                                                  partition_keys):
+def test_partition_discovery(
+    tempdir, partitioning, null_fallback, infer_dictionary, partition_keys
+):
     # ARROW-9288 / ARROW-9476
     import pyarrow.parquet as pq
-    table = pa.table({'a': range(9), 'b': [0.] * 4 + [1.] * 5})
+
+    table = pa.table({'a': range(9), 'b': [0.0] * 4 + [1.0] * 5})
+
+    has_null = None in partition_keys[0] or None in partition_keys[1]
+    if partitioning == "directory" and has_null:
+        # Directory partitioning can't handle the first part being null
+        return
 
     if partitioning == "directory":
         partitioning = ds.DirectoryPartitioning.discover(
-            ["part1", "part2"], infer_dictionary=True)
+            ["part1", "part2"], infer_dictionary=infer_dictionary)
         fmt = "{0}/{1}"
+        null_value = None
     else:
-        partitioning = ds.HivePartitioning.discover(infer_dictionary=True)
+        if null_fallback:
+            partitioning = ds.HivePartitioning.discover(
+                infer_dictionary=infer_dictionary, null_fallback=null_fallback
+            )
+        else:
+            partitioning = ds.HivePartitioning.discover(
+                infer_dictionary=infer_dictionary)
         fmt = "part1={0}/part2={1}"
+        if null_fallback:
+            null_value = null_fallback
+        else:
+            null_value = "__HIVE_DEFAULT_PARTITION__"
 
     basepath = tempdir / "dataset"
     basepath.mkdir()
@@ -1626,19 +1716,23 @@ def test_open_dataset_partitioned_dictionary_type(tempdir, partitioning,
     part_keys1, part_keys2 = partition_keys
     for part1 in part_keys1:
         for part2 in part_keys2:
-            path = basepath / fmt.format(part1, part2)
+            path = basepath / \
+                fmt.format(part1 or null_value, part2 or null_value)
             path.mkdir(parents=True)
             pq.write_table(table, path / "test.parquet")
 
     dataset = ds.dataset(str(basepath), partitioning=partitioning)
 
-    def dict_type(key):
-        value_type = pa.string() if isinstance(key, str) else pa.int32()
-        return pa.dictionary(pa.int32(), value_type)
+    def expected_type(key):
+        if infer_dictionary:
+            value_type = pa.string() if isinstance(key, str) else pa.int32()
+            return pa.dictionary(pa.int32(), value_type)
+        else:
+            return pa.string() if isinstance(key, str) else pa.int32()
 
     expected_schema = table.schema.append(
-        pa.field("part1", dict_type(part_keys1[0]))
+        pa.field("part1", expected_type(part_keys1[0]))
     ).append(
-        pa.field("part2", dict_type(part_keys2[0]))
+        pa.field("part2", expected_type(part_keys2[0]))
     )
     assert dataset.schema.equals(expected_schema)
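
On the read side, discovery accepts the same keyword, so directories written
with the fallback label round-trip back to null partition values. A sketch
under the same assumptions; the base directory path is a placeholder:

    import pyarrow.dataset as ds

    # Directories such as 'part1=xyz/part2=3' under the placeholder path are
    # expected to be read back with part1 == null.
    factory = ds.HivePartitioning.discover(
        infer_dictionary=False, null_fallback='xyz')
    dataset = ds.dataset('dataset_base_dir', partitioning=factory)
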
pa.field("part2", expected_type(part_keys2[0])) ) assert dataset.schema.equals(expected_schema) @@ -2304,6 +2398,52 @@ def test_dataset_project_only_partition_columns(tempdir): assert all_cols.column('part').equals(part_only.column('part')) +@pytest.mark.parquet +@pytest.mark.pandas +def test_write_to_dataset_given_null_just_works(tempdir): + import pyarrow.parquet as pq + + schema = pa.schema([ + pa.field('col', pa.int64()), + pa.field('part', pa.dictionary(pa.int32(), pa.string())) + ]) + table = pa.table({'part': [None, None, 'a', 'a'], + 'col': list(range(4))}, schema=schema) + + path = str(tempdir / 'test_dataset') + pq.write_to_dataset(table, path, partition_cols=[ + 'part'], use_legacy_dataset=False) + + actual_table = pq.read_table(tempdir / 'test_dataset') + # column.equals can handle the difference in chunking but not the fact + # that `part` will have different dictionaries for the two chunks + assert actual_table.column('part').to_pylist( + ) == table.column('part').to_pylist() + assert actual_table.column('col').equals(table.column('col')) + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_legacy_write_to_dataset_drops_null(tempdir): + import pyarrow.parquet as pq + + schema = pa.schema([ + pa.field('col', pa.int64()), + pa.field('part', pa.dictionary(pa.int32(), pa.string())) + ]) + table = pa.table({'part': ['a', 'a', None, None], + 'col': list(range(4))}, schema=schema) + expected = pa.table( + {'part': ['a', 'a'], 'col': list(range(2))}, schema=schema) + + path = str(tempdir / 'test_dataset') + pq.write_to_dataset(table, path, partition_cols=[ + 'part'], use_legacy_dataset=True) + + actual = pq.read_table(tempdir / 'test_dataset') + assert actual == expected + + @pytest.mark.parquet @pytest.mark.pandas def test_dataset_project_null_column(tempdir):