Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion cpp/src/arrow/compute/exec/asof_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/type_traits.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/future.h"

Expand Down Expand Up @@ -747,6 +748,7 @@ class CompositeReferenceTable {
}

switch (field_type->id()) {
ASOFJOIN_MATERIALIZE_CASE(BOOL)
ASOFJOIN_MATERIALIZE_CASE(INT8)
ASOFJOIN_MATERIALIZE_CASE(INT16)
ASOFJOIN_MATERIALIZE_CASE(INT32)
Expand Down Expand Up @@ -803,12 +805,25 @@ class CompositeReferenceTable {
}

template <class Type, class Builder = typename TypeTraits<Type>::BuilderType>
enable_if_fixed_width_type<Type, Status> static BuilderAppend(
enable_if_boolean<Type, Status> static BuilderAppend(
Builder& builder, const std::shared_ptr<ArrayData>& source, row_index_t row) {
if (source->IsNull(row)) {
builder.UnsafeAppendNull();
return Status::OK();
}
builder.UnsafeAppend(bit_util::GetBit(source->template GetValues<uint8_t>(1), row));
return Status::OK();
}

template <class Type, class Builder = typename TypeTraits<Type>::BuilderType>
enable_if_t<is_fixed_width_type<Type>::value && !is_boolean_type<Type>::value,
Status> static BuilderAppend(Builder& builder,
const std::shared_ptr<ArrayData>& source,
row_index_t row) {
if (source->IsNull(row)) {
builder.UnsafeAppendNull();
return Status::OK();
}
using CType = typename TypeTraits<Type>::CType;
builder.UnsafeAppend(source->template GetValues<CType>(1)[row]);
return Status::OK();
Expand Down Expand Up @@ -1097,6 +1112,7 @@ class AsofJoinNode : public ExecNode {

static Status is_valid_data_field(const std::shared_ptr<Field>& field) {
switch (field->type()->id()) {
case Type::BOOL:
case Type::INT8:
case Type::INT16:
case Type::INT32:
Expand Down
23 changes: 20 additions & 3 deletions cpp/src/arrow/compute/exec/asof_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "arrow/compute/exec/util.h"
#include "arrow/compute/kernels/row_encoder.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/testing/generator.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/matchers.h"
#include "arrow/testing/random.h"
Expand Down Expand Up @@ -73,8 +74,9 @@ Result<BatchesWithSchema> MakeBatchesFromNumString(
const std::vector<std::string_view>& json_strings, int multiplicity = 1) {
FieldVector num_fields;
for (auto field : schema->fields()) {
num_fields.push_back(
is_base_binary_like(field->type()->id()) ? field->WithType(int64()) : field);
auto id = field->type()->id();
bool adjust = id == Type::BOOL || is_base_binary_like(id);
num_fields.push_back(adjust ? field->WithType(int64()) : field);
}
auto num_schema =
std::make_shared<Schema>(num_fields, schema->endianness(), schema->metadata());
Expand All @@ -84,6 +86,7 @@ Result<BatchesWithSchema> MakeBatchesFromNumString(
batches.schema = schema;
int n_fields = schema->num_fields();
for (auto num_batch : num_batches.batches) {
Datum two(ConstantArrayGenerator::Int32(num_batch.length, 2));
std::vector<Datum> values;
for (int i = 0; i < n_fields; i++) {
auto type = schema->field(i)->type();
Expand All @@ -92,6 +95,18 @@ Result<BatchesWithSchema> MakeBatchesFromNumString(
ARROW_ASSIGN_OR_RAISE(Datum as_string, Cast(num_batch.values[i], utf8()));
ARROW_ASSIGN_OR_RAISE(Datum as_type, Cast(as_string, type));
values.push_back(as_type);
} else if (Type::BOOL == type->id()) {
// The next four lines compute `as_bool` as `(bool)(x - 2*(x/2))`, i.e., the low bit
// of `x`, where `x` stands for `num_batch.values[i]`, which is an `int64` value.
// Taking the low bit is a somewhat arbitrary way of obtaining both `true` and
// `false` values from the `int64` values in the test data, in order to get good
// testing coverage. A simple cast to a Boolean value would not give good coverage
// because all positive values would be cast to `true`.
ARROW_ASSIGN_OR_RAISE(Datum div_two, Divide(num_batch.values[i], two));
ARROW_ASSIGN_OR_RAISE(Datum rounded, Multiply(div_two, two));
ARROW_ASSIGN_OR_RAISE(Datum low_bit, Subtract(num_batch.values[i], rounded));
ARROW_ASSIGN_OR_RAISE(Datum as_bool, Cast(low_bit, type));
values.push_back(as_bool);
} else {
values.push_back(num_batch.values[i]);
}
Expand Down Expand Up @@ -535,6 +550,7 @@ struct BasicTest {
large_utf8(),
binary(),
large_binary(),
boolean(),
int8(),
int16(),
int32(),
Expand All @@ -559,7 +575,8 @@ struct BasicTest {
// byte_width > 1 below allows fitting the tested data
auto time_types = init_types(
all_types, [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); });
auto key_types = init_types(all_types, [](T& t) { return !is_floating(t->id()); });
auto key_types = init_types(
all_types, [](T& t) { return !is_floating(t->id()) && t->id() != Type::BOOL; });
auto l_types = init_types(all_types, [](T& t) { return true; });
auto r0_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; });
auto r1_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; });
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/exec/test_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <string>

#include "arrow/compute/exec/options.h"
Expand Down