Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 181 additions & 32 deletions cpp/src/arrow/compute/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>

#include "arrow/array.h"
#include "arrow/buffer.h"
Expand Down Expand Up @@ -68,6 +69,24 @@
namespace arrow {
namespace compute {

template <typename T>
inline const T* GetValuesAs(const ArrayData& data, int i) {
return reinterpret_cast<const T*>(data.buffers[i]->data()) + data.offset;
}

namespace {

void CopyData(const Array& input, ArrayData* output) {
auto in_data = input.data();
output->length = in_data->length;
output->null_count = input.null_count();
output->buffers = in_data->buffers;
output->offset = in_data->offset;
output->child_data = in_data->child_data;
}

} // namespace

// ----------------------------------------------------------------------
// Zero copy casts

Expand All @@ -77,7 +96,9 @@ struct is_zero_copy_cast {
};

template <typename O, typename I>
struct is_zero_copy_cast<O, I, typename std::enable_if<std::is_same<I, O>::value>::type> {
struct is_zero_copy_cast<
O, I, typename std::enable_if<std::is_same<I, O>::value &&
!std::is_base_of<ParametricType, O>::value>::type> {
static constexpr bool value = true;
};

Expand All @@ -102,10 +123,7 @@ template <typename O, typename I>
struct CastFunctor<O, I, typename std::enable_if<is_zero_copy_cast<O, I>::value>::type> {
void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input,
ArrayData* output) {
auto in_data = input.data();
output->null_count = input.null_count();
output->buffers = in_data->buffers;
output->child_data = in_data->child_data;
CopyData(input, output);
}
};

Expand All @@ -119,6 +137,7 @@ struct CastFunctor<T, NullType, typename std::enable_if<
ArrayData* output) {
// Simply initialize data to 0
auto buf = output->buffers[1];
DCHECK_EQ(output->offset, 0);
memset(buf->mutable_data(), 0, buf->size());
}
};
Expand All @@ -139,12 +158,16 @@ struct CastFunctor<T, BooleanType,
void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input,
ArrayData* output) {
using c_type = typename T::c_type;
const uint8_t* data = input.data()->buffers[1]->data();
auto out = reinterpret_cast<c_type*>(output->buffers[1]->mutable_data());
constexpr auto kOne = static_cast<c_type>(1);
constexpr auto kZero = static_cast<c_type>(0);

auto in_data = input.data();
internal::BitmapReader bit_reader(in_data->buffers[1]->data(), in_data->offset,
in_data->length);
auto out = reinterpret_cast<c_type*>(output->buffers[1]->mutable_data());
for (int64_t i = 0; i < input.length(); ++i) {
*out++ = BitUtil::GetBit(data, i) ? kOne : kZero;
*out++ = bit_reader.IsSet() ? kOne : kZero;
bit_reader.Next();
}
}
};
Expand Down Expand Up @@ -189,7 +212,9 @@ struct CastFunctor<O, I, typename std::enable_if<std::is_same<BooleanType, O>::v
void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input,
ArrayData* output) {
using in_type = typename I::c_type;
auto in_data = reinterpret_cast<const in_type*>(input.data()->buffers[1]->data());
DCHECK_EQ(output->offset, 0);

const in_type* in_data = GetValuesAs<in_type>(*input.data(), 1);
uint8_t* out_data = reinterpret_cast<uint8_t*>(output->buffers[1]->mutable_data());
for (int64_t i = 0; i < input.length(); ++i) {
BitUtil::SetBitTo(out_data, i, (*in_data++) != 0);
Expand All @@ -204,27 +229,27 @@ struct CastFunctor<O, I,
ArrayData* output) {
using in_type = typename I::c_type;
using out_type = typename O::c_type;
DCHECK_EQ(output->offset, 0);

auto in_offset = input.offset();

const auto& input_buffers = input.data()->buffers;

auto in_data = reinterpret_cast<const in_type*>(input_buffers[1]->data()) + in_offset;
const in_type* in_data = GetValuesAs<in_type>(*input.data(), 1);
auto out_data = reinterpret_cast<out_type*>(output->buffers[1]->mutable_data());

if (!options.allow_int_overflow) {
constexpr in_type kMax = static_cast<in_type>(std::numeric_limits<out_type>::max());
constexpr in_type kMin = static_cast<in_type>(std::numeric_limits<out_type>::min());

if (input.null_count() > 0) {
const uint8_t* is_valid = input_buffers[0]->data();
int64_t is_valid_offset = in_offset;
internal::BitmapReader is_valid_reader(input.data()->buffers[0]->data(),
in_offset, input.length());
for (int64_t i = 0; i < input.length(); ++i) {
if (ARROW_PREDICT_FALSE(BitUtil::GetBit(is_valid, is_valid_offset++) &&
if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet() &&
(*in_data > kMax || *in_data < kMin))) {
ctx->SetStatus(Status::Invalid("Integer value out of bounds"));
}
*out_data++ = static_cast<out_type>(*in_data++);
is_valid_reader.Next();
}
} else {
for (int64_t i = 0; i < input.length(); ++i) {
Expand All @@ -251,14 +276,133 @@ struct CastFunctor<O, I,
using in_type = typename I::c_type;
using out_type = typename O::c_type;

auto in_data = reinterpret_cast<const in_type*>(input.data()->buffers[1]->data());
const in_type* in_data = GetValuesAs<in_type>(*input.data(), 1);
auto out_data = reinterpret_cast<out_type*>(output->buffers[1]->mutable_data());
for (int64_t i = 0; i < input.length(); ++i) {
*out_data++ = static_cast<out_type>(*in_data++);
}
}
};

// ----------------------------------------------------------------------
// From one timestamp to another

template <typename in_type, typename out_type>
inline void ShiftTime(FunctionContext* ctx, const CastOptions& options,
const bool is_multiply, const int64_t factor, const Array& input,
ArrayData* output) {
const in_type* in_data = GetValuesAs<in_type>(*input.data(), 1);
auto out_data = reinterpret_cast<out_type*>(output->buffers[1]->mutable_data());

if (is_multiply) {
for (int64_t i = 0; i < input.length(); i++) {
out_data[i] = static_cast<out_type>(in_data[i] * factor);
}
} else {
if (options.allow_time_truncate) {
for (int64_t i = 0; i < input.length(); i++) {
out_data[i] = static_cast<out_type>(in_data[i] / factor);
}
} else {
for (int64_t i = 0; i < input.length(); i++) {
out_data[i] = static_cast<out_type>(in_data[i] / factor);
if (input.IsValid(i) && (out_data[i] * factor != in_data[i])) {
std::stringstream ss;
ss << "Casting from " << input.type()->ToString() << " to "
<< output->type->ToString() << " would lose data: " << in_data[i];
ctx->SetStatus(Status::Invalid(ss.str()));
break;
}
}
}
}
}

namespace {

// {is_multiply, factor}
const std::pair<bool, int64_t> kTimeConversionTable[4][4] = {
{{true, 1}, {true, 1000}, {true, 1000000}, {true, 1000000000L}}, // SECOND
{{false, 1000}, {true, 1}, {true, 1000}, {true, 1000000}}, // MILLI
{{false, 1000000}, {false, 1000}, {true, 1}, {true, 1000}}, // MICRO
{{false, 1000000000L}, {false, 1000000}, {false, 1000}, {true, 1}}, // NANO
};

} // namespace

template <>
struct CastFunctor<TimestampType, TimestampType> {
void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input,
ArrayData* output) {
// If units are the same, zero copy, otherwise convert
const auto& in_type = static_cast<const TimestampType&>(*input.type());
const auto& out_type = static_cast<const TimestampType&>(*output->type);

if (in_type.unit() == out_type.unit()) {
CopyData(input, output);
return;
}

std::pair<bool, int64_t> conversion =
kTimeConversionTable[static_cast<int>(in_type.unit())]
[static_cast<int>(out_type.unit())];

ShiftTime<int64_t, int64_t>(ctx, options, conversion.first, conversion.second, input,
output);
}
};

// ----------------------------------------------------------------------
// From one time32 or time64 to another

template <typename O, typename I>
struct CastFunctor<O, I,
typename std::enable_if<std::is_base_of<TimeType, I>::value &&
std::is_base_of<TimeType, O>::value>::type> {
void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input,
ArrayData* output) {
using in_t = typename I::c_type;
using out_t = typename O::c_type;

// If units are the same, zero copy, otherwise convert
const auto& in_type = static_cast<const I&>(*input.type());
const auto& out_type = static_cast<const O&>(*output->type);

if (in_type.unit() == out_type.unit()) {
CopyData(input, output);
return;
}

std::pair<bool, int64_t> conversion =
kTimeConversionTable[static_cast<int>(in_type.unit())]
[static_cast<int>(out_type.unit())];

ShiftTime<in_t, out_t>(ctx, options, conversion.first, conversion.second, input,
output);
}
};

// ----------------------------------------------------------------------
// Between date32 and date64

constexpr int64_t kMillisecondsInDay = 86400000;

template <>
struct CastFunctor<Date64Type, Date32Type> {
void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input,
ArrayData* output) {
ShiftTime<int32_t, int64_t>(ctx, options, true, kMillisecondsInDay, input, output);
}
};

template <>
struct CastFunctor<Date32Type, Date64Type> {
void operator()(FunctionContext* ctx, const CastOptions& options, const Array& input,
ArrayData* output) {
ShiftTime<int64_t, int32_t>(ctx, options, false, kMillisecondsInDay, input, output);
}
};

// ----------------------------------------------------------------------
// Dictionary to other things

Expand All @@ -271,9 +415,8 @@ void UnpackFixedSizeBinaryDictionary(FunctionContext* ctx, const Array& indices,
internal::BitmapReader valid_bits_reader(indices.null_bitmap_data(), indices.offset(),
indices.length());

const index_c_type* in =
reinterpret_cast<const index_c_type*>(indices.data()->buffers[1]->data()) +
indices.offset();
const index_c_type* in = GetValuesAs<index_c_type>(*indices.data(), 1);

uint8_t* out = output->buffers[1]->mutable_data();
int32_t byte_width =
static_cast<const FixedSizeBinaryType&>(*output->type).byte_width();
Expand Down Expand Up @@ -336,9 +479,7 @@ Status UnpackBinaryDictionary(FunctionContext* ctx, const Array& indices,
internal::BitmapReader valid_bits_reader(indices.null_bitmap_data(), indices.offset(),
indices.length());

const index_c_type* in =
reinterpret_cast<const index_c_type*>(indices.data()->buffers[1]->data()) +
indices.offset();
const index_c_type* in = GetValuesAs<index_c_type>(*indices.data(), 1);
for (int64_t i = 0; i < indices.length(); ++i) {
if (valid_bits_reader.IsSet()) {
int32_t length;
Expand Down Expand Up @@ -409,9 +550,7 @@ void UnpackPrimitiveDictionary(const Array& indices, const c_type* dictionary,
internal::BitmapReader valid_bits_reader(indices.null_bitmap_data(), indices.offset(),
indices.length());

const index_c_type* in =
reinterpret_cast<const index_c_type*>(indices.data()->buffers[1]->data()) +
indices.offset();
const index_c_type* in = GetValuesAs<index_c_type>(*indices.data(), 1);
for (int64_t i = 0; i < indices.length(); ++i) {
if (valid_bits_reader.IsSet()) {
out[i] = dictionary[in[i]];
Expand All @@ -436,9 +575,8 @@ struct CastFunctor<T, DictionaryType,
DCHECK(values_type.Equals(*output->type))
<< "Dictionary type: " << values_type << " target type: " << (*output->type);

auto dictionary =
reinterpret_cast<const c_type*>(type.dictionary()->data()->buffers[1]->data()) +
type.dictionary()->offset();
const c_type* dictionary = GetValuesAs<c_type>(*type.dictionary()->data(), 1);

auto out = reinterpret_cast<c_type*>(output->buffers[1]->mutable_data());
const Array& indices = *dict_array.indices();
switch (indices.type()->id()) {
Expand Down Expand Up @@ -481,6 +619,9 @@ static Status AllocateIfNotPreallocated(FunctionContext* ctx, const Array& input
int64_t bitmap_size = BitUtil::BytesForBits(length);
RETURN_NOT_OK(ctx->Allocate(bitmap_size, &validity_bitmap));
memset(validity_bitmap->mutable_data(), 0, bitmap_size);
} else if (input.offset() != 0) {
RETURN_NOT_OK(CopyBitmap(ctx->memory_pool(), validity_bitmap->data(), input.offset(),
length, &validity_bitmap));
}

if (out->buffers.size() == 2) {
Expand Down Expand Up @@ -598,13 +739,21 @@ class CastKernel : public UnaryKernel {
FN(Int64Type, Time64Type); \
FN(Int64Type, Date64Type);

#define DATE32_CASES(FN, IN_TYPE) FN(Date32Type, Date32Type);
#define DATE32_CASES(FN, IN_TYPE) \
FN(Date32Type, Date32Type); \
FN(Date32Type, Date64Type);

#define DATE64_CASES(FN, IN_TYPE) FN(Date64Type, Date64Type);
#define DATE64_CASES(FN, IN_TYPE) \
FN(Date64Type, Date64Type); \
FN(Date64Type, Date32Type);

#define TIME32_CASES(FN, IN_TYPE) FN(Time32Type, Time32Type);
#define TIME32_CASES(FN, IN_TYPE) \
FN(Time32Type, Time32Type); \
FN(Time32Type, Time64Type);

#define TIME64_CASES(FN, IN_TYPE) FN(Time64Type, Time64Type);
#define TIME64_CASES(FN, IN_TYPE) \
FN(Time64Type, Time32Type); \
FN(Time64Type, Time64Type);

#define TIMESTAMP_CASES(FN, IN_TYPE) FN(TimestampType, TimestampType);

Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/compute/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ class FunctionContext;
class UnaryKernel;

struct CastOptions {
CastOptions() : allow_int_overflow(false) {}
CastOptions() : allow_int_overflow(false), allow_time_truncate(false) {}

bool allow_int_overflow;
bool allow_time_truncate;
};

/// \since 0.7.0
Expand Down
Loading