Skip to content
14 changes: 10 additions & 4 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,27 +136,33 @@ struct CountDistinctImpl : public ScalarAggregator {
// Feeds one batch into the aggregation state.
//
// Every valid value of the batch is recorded in `memo_table_`. Afterwards
// `non_nulls` is refreshed from the memo table size, and `has_nulls` remembers
// whether any batch consumed so far contained a null.
//
// Returns the first error reported while visiting the input or inserting into
// the memo table, Status::OK() otherwise.
Status Consume(KernelContext*, const ExecBatch& batch) override {
  if (batch[0].is_array()) {
    const ArrayData& arr = *batch[0].array();
    // Accumulate instead of overwriting: a later batch without nulls must not
    // erase the nulls seen in an earlier batch consumed by the same state.
    this->has_nulls = this->has_nulls || arr.GetNullCount() > 0;

    // Nulls are tracked through `has_nulls` only, so the null visitor is a no-op.
    auto visit_null = []() { return Status::OK(); };
    auto visit_value = [&](VisitorArgType arg) {
      int32_t unused_memo_index;
      return memo_table_->GetOrInsert(arg, &unused_memo_index);
    };
    RETURN_NOT_OK(VisitArraySpanInline<Type>(arr, visit_value, visit_null));
  } else {
    const Scalar& input = *batch[0].scalar();
    this->has_nulls = this->has_nulls || !input.is_valid;

    if (input.is_valid) {
      RETURN_NOT_OK(memo_table_->MaybeInsert(UnboxScalar<Type>::Unbox(input)));
    }
  }

  // The memo table is the single source of truth for the distinct count;
  // incrementing `non_nulls` per batch would count repeated values again.
  this->non_nulls = memo_table_->size();
  return Status::OK();
}

// Folds the partial state `src` into this one: the memo tables are merged so
// that each distinct value is kept once, and the null flags are OR-ed.
Status MergeFrom(KernelContext*, KernelState&& src) override {
  const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
  this->memo_table_->MergeTable(*(other_state.memo_table_));
  // Derive the count from the merged table. Summing both sides' counts would
  // count a value present in both states twice.
  this->non_nulls = this->memo_table_->size();
  this->has_nulls = this->has_nulls || other_state.has_nulls;
  return Status::OK();
}
Expand Down
72 changes: 72 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -962,11 +962,83 @@ class TestCountDistinctKernel : public ::testing::Test {
EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one);
}

// Builds a chunked array of `type` from the JSON chunks in `json` and forwards
// it to Check() together with `expected_all` and `has_nulls`.
void CheckChunkedArr(const std::shared_ptr<DataType>& type,
                     const std::vector<std::string>& json, int64_t expected_all,
                     bool has_nulls = true) {
  const auto chunked_input = ChunkedArrayFromJSON(type, json);
  Check(chunked_input, expected_all, has_nulls);
}

// One CountOptions instance per counting mode (ONLY_VALID, ONLY_NULL, ALL),
// kept as fixture members so they can be passed by address to CallFunction.
CountOptions only_valid{CountOptions::ONLY_VALID};
CountOptions only_null{CountOptions::ONLY_NULL};
CountOptions all{CountOptions::ALL};
};

// Exercises the distinct count over chunked inputs for each of the types listed
// below. Values are deliberately repeated within and across chunks.
//
// In every case the expected count handed to CheckChunkedArr equals the number
// of distinct non-null values over all chunks, plus one when the input contains
// a null. `has_nulls` defaults to true, so null-free inputs pass it explicitly.
TEST_F(TestCountDistinctKernel, AllChunkedArrayTypesWithNulls) {
// Boolean
// Two empty chunks: nothing to count.
CheckChunkedArr(boolean(), {"[]", "[]"}, 0, /*has_nulls=*/false);
// true, false + null
CheckChunkedArr(boolean(), {"[true, null]", "[false, null, false]", "[true]"}, 3);

// Number
for (auto ty : NumericTypes()) {
// 1, 2, 5, 6, 8, 9, 10 + null
CheckChunkedArr(ty, {"[1, 1, null, 2]", "[5, 8, 9, 9, null, 10]", "[6, 6, 8, 9, 10]"},
8);
// 1, 2, 5, 6, 8, 9, 10 (no null)
CheckChunkedArr(ty, {"[1, 1, 8, 2]", "[5, 8, 9, 9, 10]", "[10, 6, 6]"}, 7,
/*has_nulls=*/false);
}

// Date
// 0, 11016, 14241 + null
CheckChunkedArr(date32(), {"[0, 11016]", "[0, null, 14241, 14241, null]"}, 4);
// 0, 1262217600000 + null
CheckChunkedArr(date64(), {"[0, null]", "[0, null, 0, 0, 1262217600000]"}, 3);

// Time
// 0, 11, 14 + null
CheckChunkedArr(time32(TimeUnit::SECOND), {"[ 0, 11, 0, null]", "[14, 14, null]"}, 4);
// 0, 11000 + null
CheckChunkedArr(time32(TimeUnit::MILLI), {"[ 0, 11000, 0]", "[null, 11000, 11000]"}, 3);

// 84203999999, 0 + null
CheckChunkedArr(time64(TimeUnit::MICRO), {"[84203999999, 0, null, 84203999999]", "[0]"},
3);
// 11715003000000, 0 + null
CheckChunkedArr(time64(TimeUnit::NANO),
{"[11715003000000, 0, null, 0, 0]", "[0, 0, null]"}, 3);

// Timestamp & Duration
// Repeated for every time unit.
for (auto u : TimeUnit::values()) {
// Two distinct durations + null
CheckChunkedArr(duration(u), {"[123456789, null, 987654321]", "[123456789, null]"},
3);

// Two distinct durations, no null
CheckChunkedArr(duration(u),
{"[123456789, 987654321, 123456789, 123456789]", "[123456789]"}, 2,
/*has_nulls=*/false);

// Two distinct timestamps + null; the same chunks are checked without and
// with a time zone.
auto ts =
std::vector<std::string>{R"(["2009-12-31T04:20:20", "2009-12-31T04:20:20"])",
R"(["2020-01-01", null])", R"(["2020-01-01", null])"};
CheckChunkedArr(timestamp(u), ts, 3);
CheckChunkedArr(timestamp(u, "Pacific/Marquesas"), ts, 3);
}

// Interval
// 9012, 5678 + null
CheckChunkedArr(month_interval(), {"[9012, 5678, null, 9012]", "[5678, null, 9012]"},
3);
// [0, 1], [1234, 5678] + null
CheckChunkedArr(day_time_interval(),
{"[[0, 1], [0, 1]]", "[null, [0, 1], [1234, 5678]]"}, 3);
// [0, 1, 2] + null
CheckChunkedArr(month_day_nano_interval(),
{"[[0, 1, 2]]", "[[0, 1, 2], null, [0, 1, 2]]"}, 2);

// Binary & String & Fixed binary
// "abc", "cba", "bca" + null; every sample is 3 bytes long so the same chunks
// also fit fixed_size_binary(3).
auto samples = std::vector<std::string>{
R"([null, "abc", null])", R"(["abc", "abc", "cba"])", R"(["bca", "cba", null])"};

CheckChunkedArr(binary(), samples, 4);
CheckChunkedArr(large_binary(), samples, 4);
CheckChunkedArr(utf8(), samples, 4);
CheckChunkedArr(large_utf8(), samples, 4);
CheckChunkedArr(fixed_size_binary(3), samples, 4);

// Decimal
// "12345.679", "98765.421" + null
samples = {R"(["12345.679", "98765.421"])", R"([null, "12345.679", "98765.421"])"};
CheckChunkedArr(decimal128(21, 3), samples, 3);
CheckChunkedArr(decimal256(13, 3), samples, 3);
}

TEST_F(TestCountDistinctKernel, AllArrayTypesWithNulls) {
// Boolean
Check(boolean(), "[]", 0, /*has_nulls=*/false);
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,22 @@ struct UnboxScalar<Type, enable_if_has_string_view<Type>> {
using T = util::string_view;
// Returns a string view over the bytes held by `val`.
//
// An invalid (null) scalar yields an empty view. Decimal128 and Decimal256
// scalars are viewed through their own `view()` accessor; any other type is
// cast to BaseBinaryScalar and its dereferenced `value` is viewed.
static T Unbox(const Scalar& val) {
  if (!val.is_valid) return util::string_view();

  // Every case returns directly, so no `break` is needed after it.
  switch (val.type->id()) {
    case arrow::Type::DECIMAL128:
      return util::string_view(checked_cast<const Decimal128Scalar&>(val).view());

    case arrow::Type::DECIMAL256:
      return util::string_view(checked_cast<const Decimal256Scalar&>(val).view());

    default:
      return util::string_view(*checked_cast<const BaseBinaryScalar&>(val).value);
  }
}
};
Expand Down
81 changes: 81 additions & 0 deletions cpp/src/arrow/util/hashing.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,22 @@ class ScalarMemoTable : public MemoTable {
value, [](int32_t i) {}, [](int32_t i) {}, out_memo_index);
}

// Inserts `value` into the table unless an equal value is already present.
//
// Returns the error reported by the hash table insertion, or Status::OK() when
// the value was inserted or was already there.
Status MaybeInsert(const Scalar& value) {
  // The comparator is only invoked by the Lookup() call below, so `value` can
  // be captured by reference instead of being copied into the closure.
  auto cmp_func = [&value](const Payload* payload) -> bool {
    return ScalarHelper<Scalar, 0>::CompareScalars(value, payload->value);
  };

  const hash_t val_hash = ComputeHash(value);
  auto hash_entry = hash_table_.Lookup(val_hash, cmp_func);

  // Insert if it wasn't found; otherwise, we're done
  if (!hash_entry.second) {
    RETURN_NOT_OK(hash_table_.Insert(hash_entry.first, val_hash, {value, size()}));
  }

  return Status::OK();
}

// Returns the stored `null_index_`.
int32_t GetNull() const { return null_index_; }

template <typename Func1, typename Func2>
Expand Down Expand Up @@ -485,6 +501,18 @@ class ScalarMemoTable : public MemoTable {
// Hashes `value` by delegating to ScalarHelper<Scalar, 0>::ComputeHash.
hash_t ComputeHash(const Scalar& value) const {
return ScalarHelper<Scalar, 0>::ComputeHash(value);
}

public:
// Defined here so that `HashTableType` is visible.
//
// Merges the entries of `other_table` into `this->hash_table_`; values already
// present are left untouched. A failed insertion is reported through
// ARROW_WARN_NOT_OK and does not stop the merge.
void MergeTable(ScalarMemoTable& other_table) {
  HashTableType& other_hashtable = other_table.hash_table_;

  // Capture `this` explicitly: relying on `[=]` to capture it implicitly is
  // deprecated since C++20.
  other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) {
    ARROW_WARN_NOT_OK(this->MaybeInsert(other_entry->payload.value),
                      "Merging ScalarMemoTable");
  });
}
};

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -545,6 +573,22 @@ class SmallScalarMemoTable : public MemoTable {
value, [](int32_t i) {}, [](int32_t i) {}, out_memo_index);
}

// Registers `value` in the table when it has no memo index yet.
// Always returns Status::OK().
Status MaybeInsert(const Scalar& value) {
  const auto slot = AsIndex(value);

  // Nothing to do when the value already owns a memo index.
  if (value_to_index_[slot] != kKeyNotFound) {
    return Status::OK();
  }

  // The next free memo index is the current number of stored values.
  const auto new_memo_index = static_cast<int32_t>(index_to_value_.size());
  index_to_value_.push_back(value);
  value_to_index_[slot] = new_memo_index;
  DCHECK_LT(new_memo_index, cardinality + 1);

  return Status::OK();
}

// Returns the entry of `value_to_index_` at position `cardinality`.
// NOTE(review): presumably that extra slot is the one reserved for null; confirm
// against the null insertion path.
int32_t GetNull() const { return value_to_index_[cardinality]; }

template <typename Func1, typename Func2>
Expand All @@ -568,6 +612,16 @@ class SmallScalarMemoTable : public MemoTable {
// (which is also 1 + the largest memo index)
int32_t size() const override {
  // Narrow the container's size to the int32_t used for memo indices.
  const auto entry_count = index_to_value_.size();
  return static_cast<int32_t>(entry_count);
}

// Merge entries from `other_table` into `this`; values already present are
// skipped. A failed insertion is reported through ARROW_WARN_NOT_OK and does
// not stop the merge.
void MergeTable(SmallScalarMemoTable& other_table) {
  for (const Scalar& other_val : other_table.index_to_value_) {
    // ARROW_WARN_NOT_OK tests the status itself, so no separate `ok()` guard
    // (previously spelled with the `not` alternative token) is required.
    ARROW_WARN_NOT_OK(this->MaybeInsert(other_val), "Merging SmallScalarMemoTable");
  }
}

// Copy values starting from index `start` into `out_data`
void CopyValues(int32_t start, Scalar* out_data) const {
DCHECK_GE(start, 0);
Expand Down Expand Up @@ -683,6 +737,26 @@ class BinaryMemoTable : public MemoTable {
return GetOrInsertNull([](int32_t i) {}, [](int32_t i) {});
}

// Inserts `value` into the table unless an equal string is already present.
//
// Returns the first error reported while appending the bytes or inserting the
// hash entry, Status::OK() otherwise.
Status MaybeInsert(const util::string_view& value) {
  const void* val_data = value.data();
  auto val_length = static_cast<builder_offset_type>(value.length());

  hash_t val_hash = ComputeStringHash<0>(val_data, val_length);
  auto hash_entry = Lookup(val_hash, val_data, val_length);

  if (!hash_entry.second) {
    // Read the memo index before appending. size() is documented as
    // "1 + the largest memo index", so it has to be sampled while the new
    // value is not yet stored; sampling it after Append() would skip an index
    // if size() follows the builder's length.
    const int32_t memo_index = size();

    // Insert string value
    RETURN_NOT_OK(
        binary_builder_.Append(static_cast<const char*>(val_data), val_length));

    // Insert hash entry
    RETURN_NOT_OK(hash_table_.Insert(const_cast<HashTableEntry*>(hash_entry.first),
                                     val_hash, {memo_index}));
  }

  return Status::OK();
}

// The number of entries in the memo table
// (which is also 1 + the largest memo index)
int32_t size() const override {
Expand Down Expand Up @@ -824,6 +898,13 @@ class BinaryMemoTable : public MemoTable {
};
return hash_table_.Lookup(h, cmp_func);
}

public:
// Merges the values of `other_table`, visited from index 0, into this table;
// values already present are skipped. A failed insertion is reported through
// ARROW_WARN_NOT_OK and does not stop the merge.
void MergeTable(BinaryMemoTable& other_table) {
  // Capture `this` explicitly: relying on `[=]` to capture it implicitly is
  // deprecated since C++20.
  other_table.VisitValues(0, [this](const util::string_view& other_value) {
    ARROW_WARN_NOT_OK(this->MaybeInsert(other_value), "Merging BinaryMemoTable");
  });
}
};

template <typename T, typename Enable = void>
Expand Down
9 changes: 9 additions & 0 deletions r/tests/testthat/test-dplyr-summarize.R
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,15 @@ test_that("Group by any/all", {
)
})

test_that("n_distinct() with many batches", {
  tf <- tempfile()
  # Remove the temporary parquet file once the test is done.
  on.exit(unlink(tf), add = TRUE)
  # NOTE(review): chunk_size = 20 is presumably what yields the "many batches"
  # named in the test title; confirm against write_parquet().
  write_parquet(dplyr::starwars, tf, chunk_size = 20)

  # The dataset query must agree with the same summary computed after
  # collecting the data.
  ds <- open_dataset(tf)
  expect_equal(
    ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(),
    ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE))
  )
})

test_that("n_distinct() on dataset", {
# With groupby
compare_dplyr_binding(
Expand Down