Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
a4202a9
Merge/rebase
westonpace Jan 21, 2021
53853b6
WIP commit
westonpace Jan 25, 2021
570de2c
Added tests of vector_hash for inputs with nulls. Added ability to s…
westonpace Jan 26, 2021
807c600
Prevent using dictionary columns as partition columns. It wouldn't w…
westonpace Jan 26, 2021
353ea9d
Addressing PR comments
westonpace Feb 2, 2021
68cf487
Taking out an extraneous using that I missed in the last commit
westonpace Feb 2, 2021
ae0b859
WIP
westonpace Feb 4, 2021
c941bae
Adding the null fallback logic to the python half
westonpace Feb 5, 2021
d502e05
WIP
westonpace Feb 8, 2021
5b18c96
WIP
westonpace Feb 11, 2021
613b286
WIP
westonpace Feb 11, 2021
3f4ec25
Improved null handling in expression/partition a bit
westonpace Feb 11, 2021
79dda1a
Added the python half of the new extract known values
westonpace Feb 12, 2021
de7be7b
Lint
westonpace Feb 12, 2021
cd00e59
Missed a test case
westonpace Feb 12, 2021
4506853
Re-lint, it appears my IDE is using the wrong style file
westonpace Feb 12, 2021
3ca5f34
Final lint pass. Turns out I was relying on black which was messing …
westonpace Feb 12, 2021
982f68c
Added more tests, rounded out a few behaviors
westonpace Feb 15, 2021
07eee3a
Added tests for SetDefaultValues to ensure it does the correct thing …
westonpace Feb 15, 2021
8f1792d
Cleaned up logic for valid but not known case
westonpace Feb 16, 2021
6f7ced5
Fixing compiler warning
westonpace Feb 16, 2021
212c9bc
Python lint
westonpace Feb 16, 2021
9ef4a71
Addressing PR comments
westonpace Feb 22, 2021
c54c55d
Update cpp/src/arrow/compute/kernels/vector_hash.cc
westonpace Feb 22, 2021
9b0f8ee
Update cpp/src/arrow/compute/kernels/vector_hash.cc
westonpace Feb 22, 2021
ce53d4e
Update cpp/src/arrow/compute/kernels/vector_hash_test.cc
westonpace Feb 22, 2021
c2aa3ad
Update cpp/src/arrow/dataset/partition_test.cc
westonpace Feb 22, 2021
7d5de82
Update python/pyarrow/_dataset.pyx
westonpace Feb 22, 2021
f1a6759
Added test case to probe what happens when inferring a partition colu…
westonpace Feb 22, 2021
dadbe8b
Use null scalars for known-null fields
bkietz Feb 19, 2021
d3bfe09
constexpr not supported in this context in all gcc versions due to gc…
westonpace Feb 23, 2021
f18c701
Missed one of the merge conflicts
westonpace Feb 23, 2021
591021e
Putting in suggestion from Ben. It got lost on rebase / force-push
westonpace Feb 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ Result<std::shared_ptr<Array>> Unique(const Datum& value, ExecContext* ctx) {
return result.make_array();
}

Result<Datum> DictionaryEncode(const Datum& value, ExecContext* ctx) {
return CallFunction("dictionary_encode", {value}, ctx);
Result<Datum> DictionaryEncode(const Datum& value, const DictionaryEncodeOptions& options,
ExecContext* ctx) {
return CallFunction("dictionary_encode", {value}, &options, ctx);
}

const char kValuesFieldName[] = "values";
Expand Down
35 changes: 34 additions & 1 deletion cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,24 @@ enum class SortOrder {
Descending,
};

/// \brief Options for the dictionary encode function
struct DictionaryEncodeOptions : public FunctionOptions {
/// Configure how null values will be encoded
enum NullEncodingBehavior {
/// the null value will be added to the dictionary with a proper index
ENCODE,
/// the null value will be masked in the indices array
MASK
};

explicit DictionaryEncodeOptions(NullEncodingBehavior null_encoding = MASK)
: null_encoding_behavior(null_encoding) {}

static DictionaryEncodeOptions Defaults() { return DictionaryEncodeOptions(); }

NullEncodingBehavior null_encoding_behavior = MASK;
};

/// \brief One sort key for PartitionNthIndices (TODO) and SortIndices
struct ARROW_EXPORT SortKey {
explicit SortKey(std::string name, SortOrder order = SortOrder::Ascending)
Expand Down Expand Up @@ -289,14 +307,29 @@ Result<std::shared_ptr<StructArray>> ValueCounts(const Datum& value,
ExecContext* ctx = NULLPTR);

/// \brief Dictionary-encode values in an array-like object
///
/// Any nulls encountered in the dictionary will be handled according to the
/// specified null encoding behavior.
///
/// For example, given values ["a", "b", null, "a", null] the output will be
/// (null_encoding == ENCODE) Indices: [0, 1, 2, 0, 2] / Dict: ["a", "b", null]
/// (null_encoding == MASK) Indices: [0, 1, null, 0, null] / Dict: ["a", "b"]
///
/// If the input is already dictionary encoded this function is a no-op unless
/// it needs to modify the null_encoding (TODO)
///
/// \param[in] data array-like input
/// \param[in] ctx the function execution context, optional
/// \param[in] options configures null encoding behavior
/// \return result with same shape and type as input
///
/// \since 1.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> DictionaryEncode(const Datum& data, ExecContext* ctx = NULLPTR);
Result<Datum> DictionaryEncode(
const Datum& data,
const DictionaryEncodeOptions& options = DictionaryEncodeOptions::Defaults(),
ExecContext* ctx = NULLPTR);

// ----------------------------------------------------------------------
// Deprecated functions
Expand Down
116 changes: 84 additions & 32 deletions cpp/src/arrow/compute/kernels/vector_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ class UniqueAction final : public ActionBase {
using ActionBase::ActionBase;

static constexpr bool with_error_status = false;
static constexpr bool with_memo_visit_null = true;

UniqueAction(const std::shared_ptr<DataType>& type, const FunctionOptions* options,
MemoryPool* pool)
: ActionBase(type, pool) {}

Status Reset() { return Status::OK(); }

Expand All @@ -76,6 +79,8 @@ class UniqueAction final : public ActionBase {
template <class Index>
void ObserveNotFound(Index index) {}

bool ShouldEncodeNulls() { return true; }

Status Flush(Datum* out) { return Status::OK(); }

Status FlushFinal(Datum* out) { return Status::OK(); }
Expand All @@ -89,9 +94,9 @@ class ValueCountsAction final : ActionBase {
using ActionBase::ActionBase;

static constexpr bool with_error_status = true;
static constexpr bool with_memo_visit_null = true;

ValueCountsAction(const std::shared_ptr<DataType>& type, MemoryPool* pool)
ValueCountsAction(const std::shared_ptr<DataType>& type, const FunctionOptions* options,
MemoryPool* pool)
: ActionBase(type, pool), count_builder_(pool) {}

Status Reserve(const int64_t length) {
Expand Down Expand Up @@ -147,6 +152,8 @@ class ValueCountsAction final : ActionBase {
}
}

bool ShouldEncodeNulls() const { return true; }

private:
Int64Builder count_builder_;
};
Expand All @@ -159,10 +166,14 @@ class DictEncodeAction final : public ActionBase {
using ActionBase::ActionBase;

static constexpr bool with_error_status = false;
static constexpr bool with_memo_visit_null = false;

DictEncodeAction(const std::shared_ptr<DataType>& type, MemoryPool* pool)
: ActionBase(type, pool), indices_builder_(pool) {}
DictEncodeAction(const std::shared_ptr<DataType>& type, const FunctionOptions* options,
MemoryPool* pool)
: ActionBase(type, pool), indices_builder_(pool) {
if (auto options_ptr = static_cast<const DictionaryEncodeOptions*>(options)) {
encode_options_ = *options_ptr;
}
}

Status Reset() {
indices_builder_.Reset();
Expand All @@ -173,12 +184,16 @@ class DictEncodeAction final : public ActionBase {

template <class Index>
void ObserveNullFound(Index index) {
indices_builder_.UnsafeAppendNull();
if (encode_options_.null_encoding_behavior == DictionaryEncodeOptions::MASK) {
indices_builder_.UnsafeAppendNull();
} else {
indices_builder_.UnsafeAppend(index);
}
}

template <class Index>
void ObserveNullNotFound(Index index) {
indices_builder_.UnsafeAppendNull();
ObserveNullFound(index);
}

template <class Index>
Expand All @@ -191,6 +206,10 @@ class DictEncodeAction final : public ActionBase {
ObserveFound(index);
}

bool ShouldEncodeNulls() {
return encode_options_.null_encoding_behavior == DictionaryEncodeOptions::ENCODE;
}

Status Flush(Datum* out) {
std::shared_ptr<ArrayData> result;
RETURN_NOT_OK(indices_builder_.FinishInternal(&result));
Expand All @@ -202,10 +221,14 @@ class DictEncodeAction final : public ActionBase {

private:
Int32Builder indices_builder_;
DictionaryEncodeOptions encode_options_;
};

class HashKernel : public KernelState {
public:
HashKernel() : options_(nullptr) {}
explicit HashKernel(const FunctionOptions* options) : options_(options) {}

// Reset for another run.
virtual Status Reset() = 0;

Expand All @@ -229,6 +252,7 @@ class HashKernel : public KernelState {
virtual Status Append(const ArrayData& arr) = 0;

protected:
const FunctionOptions* options_;
std::mutex lock_;
};

Expand All @@ -237,12 +261,12 @@ class HashKernel : public KernelState {
// (NullType has a separate implementation)

template <typename Type, typename Scalar, typename Action,
bool with_error_status = Action::with_error_status,
bool with_memo_visit_null = Action::with_memo_visit_null>
bool with_error_status = Action::with_error_status>
class RegularHashKernel : public HashKernel {
public:
RegularHashKernel(const std::shared_ptr<DataType>& type, MemoryPool* pool)
: pool_(pool), type_(type), action_(type, pool) {}
RegularHashKernel(const std::shared_ptr<DataType>& type, const FunctionOptions* options,
MemoryPool* pool)
: HashKernel(options), pool_(pool), type_(type), action_(type, options, pool) {}

Status Reset() override {
memo_table_.reset(new MemoTable(pool_, 0));
Expand Down Expand Up @@ -282,7 +306,7 @@ class RegularHashKernel : public HashKernel {
&unused_memo_index);
},
[this]() {
if (with_memo_visit_null) {
if (action_.ShouldEncodeNulls()) {
auto on_found = [this](int32_t memo_index) {
action_.ObserveNullFound(memo_index);
};
Expand Down Expand Up @@ -318,16 +342,14 @@ class RegularHashKernel : public HashKernel {
[this]() {
// Null
Status s = Status::OK();
if (with_memo_visit_null) {
auto on_found = [this](int32_t memo_index) {
action_.ObserveNullFound(memo_index);
};
auto on_not_found = [this, &s](int32_t memo_index) {
action_.ObserveNullNotFound(memo_index, &s);
};
auto on_found = [this](int32_t memo_index) {
action_.ObserveNullFound(memo_index);
};
auto on_not_found = [this, &s](int32_t memo_index) {
action_.ObserveNullNotFound(memo_index, &s);
};
if (action_.ShouldEncodeNulls()) {
memo_table_->GetOrInsertNull(std::move(on_found), std::move(on_not_found));
} else {
action_.ObserveNullNotFound(-1);
}
return s;
});
Expand All @@ -345,18 +367,23 @@ class RegularHashKernel : public HashKernel {
// ----------------------------------------------------------------------
// Hash kernel implementation for nulls

template <typename Action>
template <typename Action, bool with_error_status = Action::with_error_status>
class NullHashKernel : public HashKernel {
public:
NullHashKernel(const std::shared_ptr<DataType>& type, MemoryPool* pool)
: pool_(pool), type_(type), action_(type, pool) {}
NullHashKernel(const std::shared_ptr<DataType>& type, const FunctionOptions* options,
MemoryPool* pool)
: pool_(pool), type_(type), action_(type, options, pool) {}

Status Reset() override { return action_.Reset(); }

Status Append(const ArrayData& arr) override {
Status Append(const ArrayData& arr) override { return DoAppend(arr); }

template <bool HasError = with_error_status>
enable_if_t<!HasError, Status> DoAppend(const ArrayData& arr) {
RETURN_NOT_OK(action_.Reserve(arr.length));
for (int64_t i = 0; i < arr.length; ++i) {
if (i == 0) {
seen_null_ = true;
action_.ObserveNullNotFound(0);
} else {
action_.ObserveNullFound(0);
Expand All @@ -365,12 +392,31 @@ class NullHashKernel : public HashKernel {
return Status::OK();
}

template <bool HasError = with_error_status>
enable_if_t<HasError, Status> DoAppend(const ArrayData& arr) {
Status s = Status::OK();
RETURN_NOT_OK(action_.Reserve(arr.length));
for (int64_t i = 0; i < arr.length; ++i) {
if (seen_null_ == false && i == 0) {
seen_null_ = true;
action_.ObserveNullNotFound(0, &s);
} else {
action_.ObserveNullFound(0);
}
}
return s;
}

Status Flush(Datum* out) override { return action_.Flush(out); }
Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); }

Status GetDictionary(std::shared_ptr<ArrayData>* out) override {
// TODO(wesm): handle null being a valid dictionary value
auto null_array = std::make_shared<NullArray>(0);
std::shared_ptr<NullArray> null_array;
if (seen_null_) {
null_array = std::make_shared<NullArray>(1);
} else {
null_array = std::make_shared<NullArray>(0);
}
*out = null_array->data();
return Status::OK();
}
Expand All @@ -380,6 +426,7 @@ class NullHashKernel : public HashKernel {
protected:
MemoryPool* pool_;
std::shared_ptr<DataType> type_;
bool seen_null_ = false;
Action action_;
};

Expand Down Expand Up @@ -451,8 +498,8 @@ struct HashKernelTraits<Type, Action, enable_if_has_string_view<Type>> {
template <typename Type, typename Action>
std::unique_ptr<HashKernel> HashInitImpl(KernelContext* ctx, const KernelInitArgs& args) {
using HashKernelType = typename HashKernelTraits<Type, Action>::HashKernel;
auto result = ::arrow::internal::make_unique<HashKernelType>(args.inputs[0].type,
ctx->memory_pool());
auto result = ::arrow::internal::make_unique<HashKernelType>(
args.inputs[0].type, args.options, ctx->memory_pool());
ctx->SetStatus(result->Reset());
return std::move(result);
}
Expand Down Expand Up @@ -507,6 +554,8 @@ KernelInit GetHashInit(Type::type type_id) {
}
}

using DictionaryEncodeState = OptionsWrapper<DictionaryEncodeOptions>;

template <typename Action>
std::unique_ptr<KernelState> DictionaryHashInit(KernelContext* ctx,
const KernelInitArgs& args) {
Expand Down Expand Up @@ -639,9 +688,11 @@ const FunctionDoc value_counts_doc(
"Nulls in the input are ignored."),
{"array"});

const auto kDefaultDictionaryEncodeOptions = DictionaryEncodeOptions::Defaults();
const FunctionDoc dictionary_encode_doc(
"Dictionary-encode array",
("Return a dictionary-encoded version of the input array."), {"array"});
("Return a dictionary-encoded version of the input array."), {"array"},
"DictionaryEncodeOptions");

} // namespace

Expand Down Expand Up @@ -691,7 +742,8 @@ void RegisterVectorHash(FunctionRegistry* registry) {
// Unique and ValueCounts output unchunked arrays
base.output_chunked = true;
auto dict_encode = std::make_shared<VectorFunction>("dictionary_encode", Arity::Unary(),
&dictionary_encode_doc);
&dictionary_encode_doc,
&kDefaultDictionaryEncodeOptions);
AddHashKernels<DictEncodeAction>(dict_encode.get(), base, OutputType(DictEncodeOutput));

// Calling dictionary_encode on dictionary input not supported, but if it
Expand Down
Loading