From ba80e49c4dd7c688b1a722217e5169846de58c0a Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 16 Nov 2022 08:31:23 -0500 Subject: [PATCH 1/2] ARROW-18342: [C++] AsofJoinNode support for Boolean data field --- cpp/src/arrow/compute/exec/asof_join_node.cc | 18 ++++++++++++++- .../arrow/compute/exec/asof_join_node_test.cc | 23 ++++++++++++++++--- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index aef652e9662..73c8c4ffe2a 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -36,6 +36,7 @@ #include "arrow/result.h" #include "arrow/status.h" #include "arrow/type_traits.h" +#include "arrow/util/bit_util.h" #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" @@ -608,6 +609,7 @@ class CompositeReferenceTable { } switch (field_type->id()) { + ASOFJOIN_MATERIALIZE_CASE(BOOL) ASOFJOIN_MATERIALIZE_CASE(INT8) ASOFJOIN_MATERIALIZE_CASE(INT16) ASOFJOIN_MATERIALIZE_CASE(INT32) @@ -664,12 +666,25 @@ class CompositeReferenceTable { } template ::BuilderType> - enable_if_fixed_width_type static BuilderAppend( + enable_if_boolean static BuilderAppend( Builder& builder, const std::shared_ptr& source, row_index_t row) { if (source->IsNull(row)) { builder.UnsafeAppendNull(); return Status::OK(); } + builder.UnsafeAppend(bit_util::GetBit(source->template GetValues(1), row)); + return Status::OK(); + } + + template ::BuilderType> + enable_if_t::value && !is_boolean_type::value, + Status> static BuilderAppend(Builder& builder, + const std::shared_ptr& source, + row_index_t row) { + if (source->IsNull(row)) { + builder.UnsafeAppendNull(); + return Status::OK(); + } using CType = typename TypeTraits::CType; builder.UnsafeAppend(source->template GetValues(1)[row]); return Status::OK(); @@ -924,6 +939,7 @@ class AsofJoinNode : public ExecNode { static Status is_valid_data_field(const std::shared_ptr& field) { switch (field->type()->id()) { + case Type::BOOL: case Type::INT8: case Type::INT16: case Type::INT32: diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index d3fa6c32f47..4bec6f4a244 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -31,6 +31,7 @@ #include "arrow/compute/exec/util.h" #include "arrow/compute/kernels/row_encoder.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/testing/random.h" @@ -72,8 +73,9 @@ Result MakeBatchesFromNumString( const std::vector& json_strings, int multiplicity = 1) { FieldVector num_fields; for (auto field : schema->fields()) { - num_fields.push_back( - is_base_binary_like(field->type()->id()) ? field->WithType(int64()) : field); + auto id = field->type()->id(); + bool adjust = id == Type::BOOL || is_base_binary_like(id); + num_fields.push_back(adjust ? field->WithType(int64()) : field); } auto num_schema = std::make_shared(num_fields, schema->endianness(), schema->metadata()); @@ -83,6 +85,7 @@ Result MakeBatchesFromNumString( batches.schema = schema; int n_fields = schema->num_fields(); for (auto num_batch : num_batches.batches) { + Datum two(ConstantArrayGenerator::Int32(num_batch.length, 2)); std::vector values; for (int i = 0; i < n_fields; i++) { auto type = schema->field(i)->type(); @@ -91,6 +94,18 @@ Result MakeBatchesFromNumString( ARROW_ASSIGN_OR_RAISE(Datum as_string, Cast(num_batch.values[i], utf8())); ARROW_ASSIGN_OR_RAISE(Datum as_type, Cast(as_string, type)); values.push_back(as_type); + } else if (Type::BOOL == type->id()) { + // the next 4 lines compute `as_bool` as `(bool)(x - 2*(x/2))`, i.e., the low bit + // of `x`. Here, `x` stands for `num_batch.values[i]`, which is an `int64` value. + // Taking the low bit is a somewhat arbitrary way of obtaining both `true` and + // `false` values from the `int64` values in the test data, in order to get good + // testing coverage. A simple cast to a Boolean value would not get good coverage + // because all positive values would be cast to `true`. + ARROW_ASSIGN_OR_RAISE(Datum div_two, Divide(num_batch.values[i], two)); + ARROW_ASSIGN_OR_RAISE(Datum rounded, Multiply(div_two, two)); + ARROW_ASSIGN_OR_RAISE(Datum low_bit, Subtract(num_batch.values[i], rounded)); + ARROW_ASSIGN_OR_RAISE(Datum as_bool, Cast(low_bit, type)); + values.push_back(as_bool); } else { values.push_back(num_batch.values[i]); } @@ -526,6 +541,7 @@ struct BasicTest { large_utf8(), binary(), large_binary(), + boolean(), int8(), int16(), int32(), @@ -550,7 +566,8 @@ struct BasicTest { // byte_width > 1 below allows fitting the tested data auto time_types = init_types( all_types, [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); - auto key_types = init_types(all_types, [](T& t) { return !is_floating(t->id()); }); + auto key_types = init_types( + all_types, [](T& t) { return !is_floating(t->id()) && t->id() != Type::BOOL; }); auto l_types = init_types(all_types, [](T& t) { return true; }); auto r0_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; }); auto r1_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; }); From 18d5116ce6e87c3bd0e9fa8e122dcd5b78073763 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 17 Nov 2022 02:34:45 -0500 Subject: [PATCH 2/2] requested fixes --- cpp/src/arrow/compute/exec/asof_join_node.cc | 21 ++++++------------- .../arrow/compute/exec/asof_join_node_test.cc | 3 +-- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 73c8c4ffe2a..19bd71df44e 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -664,29 +664,20 @@ class CompositeReferenceTable { void AddRecordBatchRef(const std::shared_ptr& ref) { if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; } - template ::BuilderType> - enable_if_boolean static BuilderAppend( + enable_if_fixed_width_type static BuilderAppend( Builder& builder, const std::shared_ptr& source, row_index_t row) { if (source->IsNull(row)) { builder.UnsafeAppendNull(); return Status::OK(); } - builder.UnsafeAppend(bit_util::GetBit(source->template GetValues(1), row)); - return Status::OK(); - } - template ::BuilderType> - enable_if_t::value && !is_boolean_type::value, - Status> static BuilderAppend(Builder& builder, - const std::shared_ptr& source, - row_index_t row) { - if (source->IsNull(row)) { - builder.UnsafeAppendNull(); - return Status::OK(); + if constexpr (is_boolean_type::value) { + builder.UnsafeAppend(bit_util::GetBit(source->template GetValues(1), row)); + } else { + using CType = typename TypeTraits::CType; + builder.UnsafeAppend(source->template GetValues(1)[row]); } - using CType = typename TypeTraits::CType; - builder.UnsafeAppend(source->template GetValues(1)[row]); return Status::OK(); } diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 4bec6f4a244..31bc094c52e 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -31,7 +31,6 @@ #include "arrow/compute/exec/util.h" #include "arrow/compute/kernels/row_encoder.h" #include "arrow/compute/kernels/test_util.h" -#include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/testing/random.h" @@ -85,7 +84,7 @@ Result MakeBatchesFromNumString( batches.schema = schema; int n_fields = schema->num_fields(); for (auto num_batch : num_batches.batches) { - Datum two(ConstantArrayGenerator::Int32(num_batch.length, 2)); + Datum two(Int32Scalar(2)); std::vector values; for (int i = 0; i < n_fields; i++) { auto type = schema->field(i)->type();