Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions c_glib/test/test-decimal128-data-type.rb
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def test_type

def test_name
data_type = Arrow::Decimal128DataType.new(2, 0)
assert_equal("decimal", data_type.name)
assert_equal("decimal128", data_type.name)
end

def test_to_s
data_type = Arrow::Decimal128DataType.new(2, 0)
assert_equal("decimal(2, 0)", data_type.to_s)
assert_equal("decimal128(2, 0)", data_type.to_s)
end

def test_precision
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ std::vector<std::shared_ptr<DataType>> g_numeric_types;
std::vector<std::shared_ptr<DataType>> g_base_binary_types;
std::vector<std::shared_ptr<DataType>> g_temporal_types;
std::vector<std::shared_ptr<DataType>> g_primitive_types;
std::vector<Type::type> g_decimal_type_ids;
static std::once_flag codegen_static_initialized;

template <typename T>
Expand All @@ -71,6 +72,9 @@ static void InitStaticData() {
// Floating point types
g_floating_types = {float32(), float64()};

// Decimal types
g_decimal_type_ids = {Type::DECIMAL128, Type::DECIMAL256};

// Numeric types
Extend(g_int_types, &g_numeric_types);
Extend(g_floating_types, &g_numeric_types);
Expand Down Expand Up @@ -132,6 +136,11 @@ const std::vector<std::shared_ptr<DataType>>& FloatingPointTypes() {
return g_floating_types;
}

const std::vector<Type::type>& DecimalTypeIds() {
std::call_once(codegen_static_initialized, InitStaticData);
return g_decimal_type_ids;
}

const std::vector<TimeUnit::type>& AllTimeUnits() {
static std::vector<TimeUnit::type> units = {TimeUnit::SECOND, TimeUnit::MILLI,
TimeUnit::MICRO, TimeUnit::NANO};
Expand Down
32 changes: 32 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@ struct GetViewType<Decimal128Type> {
}
};

template <>
struct GetViewType<Decimal256Type> {
using T = Decimal256;
using PhysicalType = util::string_view;

static T LogicalValue(PhysicalType value) {
return Decimal256(reinterpret_cast<const uint8_t*>(value.data()));
}
};

template <typename Type, typename Enable = void>
struct GetOutputType;

Expand All @@ -206,6 +216,11 @@ struct GetOutputType<Decimal128Type> {
using T = Decimal128;
};

template <>
struct GetOutputType<Decimal256Type> {
using T = Decimal256;
};

// ----------------------------------------------------------------------
// Iteration / value access utilities

Expand Down Expand Up @@ -396,6 +411,7 @@ const std::vector<std::shared_ptr<DataType>>& SignedIntTypes();
const std::vector<std::shared_ptr<DataType>>& UnsignedIntTypes();
const std::vector<std::shared_ptr<DataType>>& IntTypes();
const std::vector<std::shared_ptr<DataType>>& FloatingPointTypes();
const std::vector<Type::type>& DecimalTypeIds();

ARROW_EXPORT
const std::vector<TimeUnit::type>& AllTimeUnits();
Expand Down Expand Up @@ -1185,6 +1201,22 @@ ArrayKernelExec GenerateTemporal(detail::GetTypeId get_id) {
}
}

// Generate a kernel given a templated functor for decimal types
//
// See "Numeric" above for description of the generator functor
template <template <typename...> class Generator, typename Type0, typename... Args>
ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
switch (get_id.id) {
case Type::DECIMAL128:
return Generator<Type0, Decimal128Type, Args...>::Exec;
case Type::DECIMAL256:
return Generator<Type0, Decimal256Type, Args...>::Exec;
default:
DCHECK(false);
return ExecFail;
}
}

// END of kernel generator-dispatchers
// ----------------------------------------------------------------------

Expand Down
79 changes: 55 additions & 24 deletions cpp/src/arrow/compute/kernels/vector_sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,18 @@ namespace internal {
VISIT(FloatType) \
VISIT(DoubleType) \
VISIT(BinaryType) \
VISIT(LargeBinaryType)
VISIT(LargeBinaryType) \
VISIT(FixedSizeBinaryType) \
VISIT(Decimal128Type) \
VISIT(Decimal256Type)

namespace {

// The target chunk in a chunked array.
template <typename ArrayType>
struct ResolvedChunk {
using ViewType = decltype(std::declval<ArrayType>().GetView(0));
using V = GetViewType<typename ArrayType::TypeClass>;
using LogicalValueType = typename V::T;

// The target array in chunked array.
const ArrayType* array;
Expand All @@ -70,7 +74,7 @@ struct ResolvedChunk {

bool IsNull() const { return array->IsNull(index); }

ViewType GetView() const { return array->GetView(index); }
LogicalValueType Value() const { return V::LogicalValue(array->GetView(index)); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to refactor/rename GetViewType, this was very confusing to read (not in any way that's your fault)

};

// ResolvedChunk specialization for untyped arrays when all is needed is null lookup
Expand Down Expand Up @@ -279,7 +283,7 @@ PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
ChunkedArrayResolver resolver(arrays);
return partitioner(indices_begin, indices_end, [&](uint64_t ind) {
const auto chunk = resolver.Resolve<ArrayType>(ind);
return !std::isnan(chunk.GetView());
return !std::isnan(chunk.Value());
});
}

Expand Down Expand Up @@ -318,6 +322,8 @@ struct PartitionNthToIndices {
using ArrayType = typename TypeTraits<InType>::ArrayType;

static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
using GetView = GetViewType<InType>;

if (ctx->state() == nullptr) {
ctx->SetStatus(Status::Invalid("NthToIndices requires PartitionNthOptions"));
return;
Expand All @@ -343,7 +349,9 @@ struct PartitionNthToIndices {
if (nth_begin < nulls_begin) {
std::nth_element(out_begin, nth_begin, nulls_begin,
[&arr](uint64_t left, uint64_t right) {
return arr.GetView(left) < arr.GetView(right);
const auto lval = GetView::LogicalValue(arr.GetView(left));
const auto rval = GetView::LogicalValue(arr.GetView(right));
return lval < rval;
});
}
}
Expand All @@ -365,6 +373,7 @@ inline void VisitRawValuesInline(const ArrayType& values,
template <typename ArrowType>
class ArrayCompareSorter {
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
using GetView = GetViewType<ArrowType>;

public:
// Returns where null starts.
Expand All @@ -377,14 +386,18 @@ class ArrayCompareSorter {
if (options.order == SortOrder::Ascending) {
std::stable_sort(
indices_begin, nulls_begin, [&values, &offset](uint64_t left, uint64_t right) {
return values.GetView(left - offset) < values.GetView(right - offset);
const auto lhs = GetView::LogicalValue(values.GetView(left - offset));
const auto rhs = GetView::LogicalValue(values.GetView(right - offset));
return lhs < rhs;
});
} else {
std::stable_sort(
indices_begin, nulls_begin, [&values, &offset](uint64_t left, uint64_t right) {
const auto lhs = GetView::LogicalValue(values.GetView(left - offset));
const auto rhs = GetView::LogicalValue(values.GetView(right - offset));
// We don't use 'left > right' here to reduce required operator.
// If we use 'right < left' here, '<' is only required.
return values.GetView(right - offset) < values.GetView(left - offset);
return rhs < lhs;
});
}
return nulls_begin;
Expand Down Expand Up @@ -542,8 +555,9 @@ struct ArraySorter<Type, enable_if_t<(is_integer_type<Type>::value &&
};

template <typename Type>
struct ArraySorter<Type, enable_if_t<is_floating_type<Type>::value ||
is_base_binary_type<Type>::value>> {
struct ArraySorter<
Type, enable_if_t<is_floating_type<Type>::value || is_base_binary_type<Type>::value ||
is_fixed_size_binary_type<Type>::value>> {
ArrayCompareSorter<Type> impl;
};

Expand Down Expand Up @@ -585,12 +599,21 @@ void AddSortingKernels(VectorKernel base, VectorFunction* func) {
base.exec = GenerateNumeric<ExecTemplate, UInt64Type>(*physical_type);
DCHECK_OK(func->AddKernel(base));
}
for (const auto id : DecimalTypeIds()) {
base.signature = KernelSignature::Make({InputType::Array(id)}, uint64());
base.exec = GenerateDecimal<ExecTemplate, UInt64Type>(id);
DCHECK_OK(func->AddKernel(base));
}
for (const auto& ty : BaseBinaryTypes()) {
auto physical_type = GetPhysicalType(ty);
base.signature = KernelSignature::Make({InputType::Array(ty)}, uint64());
base.exec = GenerateVarBinaryBase<ExecTemplate, UInt64Type>(*physical_type);
DCHECK_OK(func->AddKernel(base));
}
base.signature =
KernelSignature::Make({InputType::Array(Type::FIXED_SIZE_BINARY)}, uint64());
base.exec = ExecTemplate<UInt64Type, FixedSizeBinaryType>::Exec;
DCHECK_OK(func->AddKernel(base));
}

// ----------------------------------------------------------------------
Expand All @@ -617,15 +640,15 @@ class ChunkedArrayCompareSorter {
std::stable_sort(indices_begin, nulls_begin, [&](uint64_t left, uint64_t right) {
const auto chunk_left = resolver.Resolve<ArrayType>(left);
const auto chunk_right = resolver.Resolve<ArrayType>(right);
return chunk_left.GetView() < chunk_right.GetView();
return chunk_left.Value() < chunk_right.Value();
});
} else {
std::stable_sort(indices_begin, nulls_begin, [&](uint64_t left, uint64_t right) {
const auto chunk_left = resolver.Resolve<ArrayType>(left);
const auto chunk_right = resolver.Resolve<ArrayType>(right);
// We don't use 'left > right' here to reduce required operator.
// If we use 'right < left' here, '<' is only required.
return chunk_right.GetView() < chunk_left.GetView();
return chunk_right.Value() < chunk_left.Value();
});
}
return nulls_begin;
Expand Down Expand Up @@ -786,7 +809,7 @@ class ChunkedArraySorter : public TypeVisitor {
[&](uint64_t left, uint64_t right) {
const auto chunk_left = left_resolver.Resolve<ArrayType>(left);
const auto chunk_right = right_resolver.Resolve<ArrayType>(right);
return chunk_left.GetView() < chunk_right.GetView();
return chunk_left.Value() < chunk_right.Value();
});
} else {
std::merge(indices_begin, indices_middle, indices_middle, indices_end, temp_indices,
Expand All @@ -796,7 +819,7 @@ class ChunkedArraySorter : public TypeVisitor {
// We don't use 'left > right' here to reduce required
// operator. If we use 'right < left' here, '<' is only
// required.
return chunk_right.GetView() < chunk_left.GetView();
return chunk_right.Value() < chunk_left.Value();
});
}
// Copy back temp area into main buffer
Expand All @@ -822,14 +845,16 @@ class ChunkedArraySorter : public TypeVisitor {
template <typename ArrayType, typename Visitor>
void VisitConstantRanges(const ArrayType& array, uint64_t* indices_begin,
uint64_t* indices_end, Visitor&& visit) {
using GetView = GetViewType<typename ArrayType::TypeClass>;

if (indices_begin == indices_end) {
return;
}
auto range_start = indices_begin;
auto range_cur = range_start;
auto last_value = array.GetView(*range_cur);
auto last_value = GetView::LogicalValue(array.GetView(*range_cur));
while (++range_cur != indices_end) {
auto v = array.GetView(*range_cur);
auto v = GetView::LogicalValue(array.GetView(*range_cur));
if (v != last_value) {
visit(range_start, range_cur);
range_start = range_cur;
Expand Down Expand Up @@ -869,6 +894,8 @@ class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter {
null_count_(array_.null_count()) {}

void SortRange(uint64_t* indices_begin, uint64_t* indices_end) {
using GetView = GetViewType<Type>;

constexpr int64_t offset = 0;
uint64_t* nulls_begin;
if (null_count_ == 0) {
Expand All @@ -889,14 +916,18 @@ class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter {
if (order_ == SortOrder::Ascending) {
std::stable_sort(
indices_begin, null_likes_begin, [&](uint64_t left, uint64_t right) {
return array_.GetView(left - offset) < array_.GetView(right - offset);
const auto lhs = GetView::LogicalValue(array_.GetView(left - offset));
const auto rhs = GetView::LogicalValue(array_.GetView(right - offset));
return lhs < rhs;
});
} else {
std::stable_sort(
indices_begin, null_likes_begin, [&](uint64_t left, uint64_t right) {
// We don't use 'left > right' here to reduce required operator.
// If we use 'right < left' here, '<' is only required.
return array_.GetView(right - offset) < array_.GetView(left - offset);
const auto lhs = GetView::LogicalValue(array_.GetView(left - offset));
const auto rhs = GetView::LogicalValue(array_.GetView(right - offset));
return lhs > rhs;
});
}

Expand Down Expand Up @@ -1100,8 +1131,8 @@ class MultipleKeyComparator {
const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_left,
const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_right,
const SortOrder order) {
const auto left = chunk_left.GetView();
const auto right = chunk_right.GetView();
const auto left = chunk_left.Value();
const auto right = chunk_right.Value();
int32_t compared;
if (left == right) {
compared = 0;
Expand All @@ -1122,8 +1153,8 @@ class MultipleKeyComparator {
const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_left,
const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_right,
const SortOrder order) {
const auto left = chunk_left.GetView();
const auto right = chunk_right.GetView();
const auto left = chunk_left.Value();
const auto right = chunk_right.Value();
auto is_nan_left = std::isnan(left);
auto is_nan_right = std::isnan(right);
if (is_nan_left && is_nan_right) {
Expand Down Expand Up @@ -1439,8 +1470,8 @@ class MultipleKeyTableSorter : public TypeVisitor {
// Both values are never null nor NaN.
auto chunk_left = first_sort_key.GetChunk<ArrayType>(left);
auto chunk_right = first_sort_key.GetChunk<ArrayType>(right);
auto value_left = chunk_left.GetView();
auto value_right = chunk_right.GetView();
auto value_left = chunk_left.Value();
auto value_right = chunk_right.Value();
if (value_left == value_right) {
// If the left value equals to the right value,
// we need to compare the second and following
Expand Down Expand Up @@ -1502,7 +1533,7 @@ class MultipleKeyTableSorter : public TypeVisitor {
DCHECK_EQ(indices_end_ - nulls_begin, first_sort_key.null_count);
uint64_t* nans_begin = partitioner(indices_begin_, nulls_begin, [&](uint64_t index) {
const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
return !std::isnan(chunk.GetView());
return !std::isnan(chunk.Value());
});
auto& comparator = comparator_;
// Sort all NaNs by the second and following sort keys.
Expand Down
Loading