diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index e262a254a2e..28765567514 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -39,6 +39,7 @@ #include "arrow/result.h" #include "arrow/status.h" #include "arrow/type_traits.h" +#include "arrow/util/bit_util.h" #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" @@ -747,6 +748,7 @@ class CompositeReferenceTable { } switch (field_type->id()) { + ASOFJOIN_MATERIALIZE_CASE(BOOL) ASOFJOIN_MATERIALIZE_CASE(INT8) ASOFJOIN_MATERIALIZE_CASE(INT16) ASOFJOIN_MATERIALIZE_CASE(INT32) @@ -803,12 +805,25 @@ class CompositeReferenceTable { } template ::BuilderType> - enable_if_fixed_width_type static BuilderAppend( + enable_if_boolean static BuilderAppend( Builder& builder, const std::shared_ptr& source, row_index_t row) { if (source->IsNull(row)) { builder.UnsafeAppendNull(); return Status::OK(); } + builder.UnsafeAppend(bit_util::GetBit(source->template GetValues(1), row)); + return Status::OK(); + } + + template ::BuilderType> + enable_if_t::value && !is_boolean_type::value, + Status> static BuilderAppend(Builder& builder, + const std::shared_ptr& source, + row_index_t row) { + if (source->IsNull(row)) { + builder.UnsafeAppendNull(); + return Status::OK(); + } using CType = typename TypeTraits::CType; builder.UnsafeAppend(source->template GetValues(1)[row]); return Status::OK(); @@ -1097,6 +1112,7 @@ class AsofJoinNode : public ExecNode { static Status is_valid_data_field(const std::shared_ptr& field) { switch (field->type()->id()) { + case Type::BOOL: case Type::INT8: case Type::INT16: case Type::INT32: diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 66ed873620d..bfda559271f 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -32,6 +32,7 @@ #include "arrow/compute/exec/util.h" #include "arrow/compute/kernels/row_encoder.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/testing/random.h" @@ -73,8 +74,9 @@ Result MakeBatchesFromNumString( const std::vector& json_strings, int multiplicity = 1) { FieldVector num_fields; for (auto field : schema->fields()) { - num_fields.push_back( - is_base_binary_like(field->type()->id()) ? field->WithType(int64()) : field); + auto id = field->type()->id(); + bool adjust = id == Type::BOOL || is_base_binary_like(id); + num_fields.push_back(adjust ? field->WithType(int64()) : field); } auto num_schema = std::make_shared(num_fields, schema->endianness(), schema->metadata()); @@ -84,6 +86,7 @@ Result MakeBatchesFromNumString( batches.schema = schema; int n_fields = schema->num_fields(); for (auto num_batch : num_batches.batches) { + Datum two(ConstantArrayGenerator::Int32(num_batch.length, 2)); std::vector values; for (int i = 0; i < n_fields; i++) { auto type = schema->field(i)->type(); @@ -92,6 +95,18 @@ Result MakeBatchesFromNumString( ARROW_ASSIGN_OR_RAISE(Datum as_string, Cast(num_batch.values[i], utf8())); ARROW_ASSIGN_OR_RAISE(Datum as_type, Cast(as_string, type)); values.push_back(as_type); + } else if (Type::BOOL == type->id()) { + // the next 4 lines compute `as_bool` as `(bool)(x - 2*(x/2))`, i.e., the low bit + // of `x`. Here, `x` stands for `num_batch.values[i]`, which is an `int64` value. + // Taking the low bit is a somewhat arbitrary way of obtaining both `true` and + // `false` values from the `int64` values in the test data, in order to get good + // testing coverage. A simple cast to a Boolean value would not get good coverage + // because all positive values would be cast to `true`. + ARROW_ASSIGN_OR_RAISE(Datum div_two, Divide(num_batch.values[i], two)); + ARROW_ASSIGN_OR_RAISE(Datum rounded, Multiply(div_two, two)); + ARROW_ASSIGN_OR_RAISE(Datum low_bit, Subtract(num_batch.values[i], rounded)); + ARROW_ASSIGN_OR_RAISE(Datum as_bool, Cast(low_bit, type)); + values.push_back(as_bool); } else { values.push_back(num_batch.values[i]); } @@ -535,6 +550,7 @@ struct BasicTest { large_utf8(), binary(), large_binary(), + boolean(), int8(), int16(), int32(), @@ -559,7 +575,8 @@ struct BasicTest { // byte_width > 1 below allows fitting the tested data auto time_types = init_types( all_types, [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); - auto key_types = init_types(all_types, [](T& t) { return !is_floating(t->id()); }); + auto key_types = init_types( + all_types, [](T& t) { return !is_floating(t->id()) && t->id() != Type::BOOL; }); auto l_types = init_types(all_types, [](T& t) { return true; }); auto r0_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; }); auto r1_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; }); diff --git a/cpp/src/arrow/compute/exec/test_nodes.h b/cpp/src/arrow/compute/exec/test_nodes.h index a117df0c460..d8954ed27db 100644 --- a/cpp/src/arrow/compute/exec/test_nodes.h +++ b/cpp/src/arrow/compute/exec/test_nodes.h @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#pragma once + #include #include "arrow/compute/exec/options.h"