diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index aef652e9662..19bd71df44e 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -36,6 +36,7 @@ #include "arrow/result.h" #include "arrow/status.h" #include "arrow/type_traits.h" +#include "arrow/util/bit_util.h" #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" @@ -608,6 +609,7 @@ class CompositeReferenceTable { } switch (field_type->id()) { + ASOFJOIN_MATERIALIZE_CASE(BOOL) ASOFJOIN_MATERIALIZE_CASE(INT8) ASOFJOIN_MATERIALIZE_CASE(INT16) ASOFJOIN_MATERIALIZE_CASE(INT32) @@ -662,7 +664,6 @@ class CompositeReferenceTable { void AddRecordBatchRef(const std::shared_ptr& ref) { if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; } - template ::BuilderType> enable_if_fixed_width_type static BuilderAppend( Builder& builder, const std::shared_ptr& source, row_index_t row) { @@ -670,8 +671,13 @@ class CompositeReferenceTable { builder.UnsafeAppendNull(); return Status::OK(); } - using CType = typename TypeTraits::CType; - builder.UnsafeAppend(source->template GetValues(1)[row]); + + if constexpr (is_boolean_type::value) { + builder.UnsafeAppend(bit_util::GetBit(source->template GetValues(1), row)); + } else { + using CType = typename TypeTraits::CType; + builder.UnsafeAppend(source->template GetValues(1)[row]); + } return Status::OK(); } @@ -924,6 +930,7 @@ class AsofJoinNode : public ExecNode { static Status is_valid_data_field(const std::shared_ptr& field) { switch (field->type()->id()) { + case Type::BOOL: case Type::INT8: case Type::INT16: case Type::INT32: diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index d3fa6c32f47..31bc094c52e 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -72,8 +72,9 @@ Result MakeBatchesFromNumString( const std::vector& json_strings, int multiplicity = 1) { FieldVector num_fields; for (auto field : schema->fields()) { - num_fields.push_back( - is_base_binary_like(field->type()->id()) ? field->WithType(int64()) : field); + auto id = field->type()->id(); + bool adjust = id == Type::BOOL || is_base_binary_like(id); + num_fields.push_back(adjust ? field->WithType(int64()) : field); } auto num_schema = std::make_shared(num_fields, schema->endianness(), schema->metadata()); @@ -83,6 +84,7 @@ Result MakeBatchesFromNumString( batches.schema = schema; int n_fields = schema->num_fields(); for (auto num_batch : num_batches.batches) { + Datum two(Int32Scalar(2)); std::vector values; for (int i = 0; i < n_fields; i++) { auto type = schema->field(i)->type(); @@ -91,6 +93,18 @@ Result MakeBatchesFromNumString( ARROW_ASSIGN_OR_RAISE(Datum as_string, Cast(num_batch.values[i], utf8())); ARROW_ASSIGN_OR_RAISE(Datum as_type, Cast(as_string, type)); values.push_back(as_type); + } else if (Type::BOOL == type->id()) { + // the next 4 lines compute `as_bool` as `(bool)(x - 2*(x/2))`, i.e., the low bit + // of `x`. Here, `x` stands for `num_batch.values[i]`, which is an `int64` value. + // Taking the low bit is a somewhat arbitrary way of obtaining both `true` and + // `false` values from the `int64` values in the test data, in order to get good + // testing coverage. A simple cast to a Boolean value would not get good coverage + // because all positive values would be cast to `true`. + ARROW_ASSIGN_OR_RAISE(Datum div_two, Divide(num_batch.values[i], two)); + ARROW_ASSIGN_OR_RAISE(Datum rounded, Multiply(div_two, two)); + ARROW_ASSIGN_OR_RAISE(Datum low_bit, Subtract(num_batch.values[i], rounded)); + ARROW_ASSIGN_OR_RAISE(Datum as_bool, Cast(low_bit, type)); + values.push_back(as_bool); } else { values.push_back(num_batch.values[i]); } @@ -526,6 +540,7 @@ struct BasicTest { large_utf8(), binary(), large_binary(), + boolean(), int8(), int16(), int32(), @@ -550,7 +565,8 @@ struct BasicTest { // byte_width > 1 below allows fitting the tested data auto time_types = init_types( all_types, [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); - auto key_types = init_types(all_types, [](T& t) { return !is_floating(t->id()); }); + auto key_types = init_types( + all_types, [](T& t) { return !is_floating(t->id()) && t->id() != Type::BOOL; }); auto l_types = init_types(all_types, [](T& t) { return true; }); auto r0_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; }); auto r1_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; });