Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ struct ARROW_EXPORT TrimOptions : public FunctionOptions {
std::string characters;
};

/// \brief Options describing a slice as a (start, stop, step) triple.
///
/// `stop` defaults to the largest representable int64_t value and `step`
/// defaults to 1; `start` must always be given.
struct ARROW_EXPORT SliceOptions : public FunctionOptions {
  int64_t start;
  int64_t stop;
  int64_t step;

  explicit SliceOptions(int64_t start,
                        int64_t stop = std::numeric_limits<int64_t>::max(),
                        int64_t step = 1)
      : start{start}, stop{stop}, step{step} {}
};

enum CompareOperator : int8_t {
EQUAL,
NOT_EQUAL,
Expand Down
263 changes: 237 additions & 26 deletions cpp/src/arrow/compute/kernels/scalar_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,10 @@ struct StringTransform {
using offset_type = typename Type::offset_type;
using ArrayType = typename TypeTraits<Type>::ArrayType;

// Upper bound on the number of output code units; this form simply returns the
// input size unchanged.
// NOTE(review): this diff view also shows a virtual two-argument MaxCodeunits
// directly below; presumably this one-argument static form is the superseded
// line -- confirm against the diff markers.
static int64_t MaxCodeunits(offset_type input_ncodeunits) { return input_ncodeunits; }
// Returns an upper bound on the number of code units the output may need, given
// the number of input strings (`ninputs`) and the total number of input code
// units (`input_ncodeunits`). This default ignores `ninputs` and returns
// `input_ncodeunits`; it is virtual so that a transform can supply its own
// bound.
virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) {
return input_ncodeunits;
}

// Static entry point: constructs a temporary `Derived` transform and forwards
// the context, batch and output datum to its Execute().
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
return Derived().Execute(ctx, batch, out);
}
Expand All @@ -156,7 +159,8 @@ struct StringTransform {
offset_type input_ncodeunits = input_boxed.total_values_length();
offset_type input_nstrings = static_cast<offset_type>(input.length);

int64_t output_ncodeunits_max = Derived::MaxCodeunits(input_ncodeunits);
const int64_t output_ncodeunits_max =
MaxCodeunits(input_nstrings, input_ncodeunits);
if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
return Status::CapacityError(
"Result might not fit in a 32bit utf8 array, convert to large_utf8");
Expand All @@ -183,35 +187,36 @@ struct StringTransform {
output_ncodeunits += encoded_nbytes;
output_string_offsets[i + 1] = output_ncodeunits;
}
DCHECK_LE(output_ncodeunits, output_ncodeunits_max);

// Trim the codepoint buffer, since we allocated too much
RETURN_NOT_OK(values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true));
return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true);
} else {
DCHECK_EQ(batch[0].kind(), Datum::SCALAR);
const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar());
auto result = checked_pointer_cast<BaseBinaryScalar>(MakeNullScalar(out->type()));
if (input.is_valid) {
result->is_valid = true;
offset_type data_nbytes = static_cast<offset_type>(input.value->size());
if (!input.is_valid) {
return Status::OK();
}
auto* result = checked_cast<BaseBinaryScalar*>(out->scalar().get());
result->is_valid = true;
offset_type data_nbytes = static_cast<offset_type>(input.value->size());

int64_t output_ncodeunits_max = Derived::MaxCodeunits(data_nbytes);
if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
return Status::CapacityError(
"Result might not fit in a 32bit utf8 array, convert to large_utf8");
}
ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max));
result->value = value_buffer;
offset_type encoded_nbytes = 0;
if (ARROW_PREDICT_FALSE(!static_cast<Derived&>(*this).Transform(
input.value->data(), data_nbytes, value_buffer->mutable_data(),
&encoded_nbytes))) {
return Derived::InvalidStatus();
}
RETURN_NOT_OK(value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true));
int64_t output_ncodeunits_max = MaxCodeunits(1, data_nbytes);
if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
return Status::CapacityError(
"Result might not fit in a 32bit utf8 array, convert to large_utf8");
}
out->value = result;
ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max));
result->value = value_buffer;
offset_type encoded_nbytes = 0;
if (ARROW_PREDICT_FALSE(!static_cast<Derived&>(*this).Transform(
input.value->data(), data_nbytes, value_buffer->mutable_data(),
&encoded_nbytes))) {
return Derived::InvalidStatus();
}
DCHECK_LE(encoded_nbytes, output_ncodeunits_max);
return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true);
}

return Status::OK();
}
};

Expand All @@ -234,7 +239,8 @@ struct StringTransformCodepoint : StringTransform<Type, Derived> {
*output_written = static_cast<offset_type>(output - output_start);
return true;
}
static int64_t MaxCodeunits(offset_type input_ncodeunits) {

int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override {
// Section 5.18 of the Unicode spec claim that the number of codepoints for case
// mapping can grow by a factor of 3. This means grow by a factor of 3 in bytes
// However, since we don't support all casings (SpecialCasing.txt) the growth
Expand All @@ -243,6 +249,7 @@ struct StringTransformCodepoint : StringTransform<Type, Derived> {
// two code units (even) can grow to 3 code units.
return static_cast<int64_t>(input_ncodeunits) * 3 / 2;
}

Status Execute(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
EnsureLookupTablesFilled();
return Base::Execute(ctx, batch, out);
Expand Down Expand Up @@ -758,6 +765,209 @@ void AddFindSubstring(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}

// Slicing

/// \brief Common base for slicing transforms.
///
/// Holds the SliceOptions used by the transform and provides the static kernel
/// entry point, which validates the options before running the transform.
template <typename Type, typename Derived>
struct SliceBase : StringTransform<Type, Derived> {
  using Base = StringTransform<Type, Derived>;
  using offset_type = typename Base::offset_type;
  using State = OptionsWrapper<SliceOptions>;

  SliceOptions options;

  explicit SliceBase(SliceOptions options) : options(options) {}

  /// Fetches the SliceOptions from the kernel state, rejects a zero step, and
  /// otherwise runs a `Derived` transform built from those options.
  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
    const SliceOptions slice_options = State::Get(ctx);
    if (slice_options.step != 0) {
      Derived transform(slice_options);
      return transform.Execute(ctx, batch, out);
    }
    return Status::Invalid("Slice step cannot be zero");
  }
};

// Makes the enclosing function return false when `expr` evaluates to false.
// The argument is parenthesized before negation so that a compound expression
// such as `a == b` is negated as a whole rather than only its first operand.
#define PROPAGATE_FALSE(expr)             \
  do {                                    \
    if (ARROW_PREDICT_FALSE(!(expr))) {   \
      return false;                       \
    }                                     \
  } while (0)

// Slices one UTF-8 encoded string according to `options` and writes the result
// to `output`.
//
// `input` points at `input_string_ncodeunits` bytes. `options.start` and
// `options.stop` are counted in codepoints (the boundaries are located with the
// UTF8AdvanceCodepoints* helpers); negative values are counted from the end of
// the string. `options.step` selects every step-th codepoint, walking backwards
// when it is negative. The number of bytes written is stored in
// `*output_written`.
//
// Returns false as soon as one of the UTF-8 helpers reports failure (via
// PROPAGATE_FALSE), true otherwise. On a false return `*output_written` is not
// set by this function.
// NOTE(review): `options.step == 0` is not handled specially here and would take
// the `step < 0` branch; SliceBase::Exec rejects it before this is reached.
bool SliceCodeunitsTransform(const uint8_t* input, int64_t input_string_ncodeunits,
uint8_t* output, int64_t* output_written,
const SliceOptions& options) {
const uint8_t* begin = input;
const uint8_t* end = input + input_string_ncodeunits;
// The slice boundaries default to the whole string and are narrowed below.
const uint8_t* begin_sliced = begin;
const uint8_t* end_sliced = end;

if (options.step >= 1) {
if (options.start >= 0) {
// start counting from the left
PROPAGATE_FALSE(
arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, options.start));
if (options.stop > options.start) {
// continue counting from begin_sliced
int64_t length = options.stop - options.start;
PROPAGATE_FALSE(
arrow::util::UTF8AdvanceCodepoints(begin_sliced, end, &end_sliced, length));
} else if (options.stop < 0) {
// or from the end (but we never need to go before begin_sliced)
PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(
begin_sliced, end, &end_sliced, -options.stop));
} else {
// zero length slice
*output_written = 0;
return true;
}
} else {
// start counting from the right
PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &begin_sliced,
-options.start));
if (options.stop > 0) {
// continue counting from the left, we cannot start from begin_sliced because we
// don't know how many codepoints are between begin and begin_sliced
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... for step < 0, wouldn't it be more logical to first compute end_sliced? Presumably, you then can pass end_sliced as one of the boundaries for computing begin_sliced?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a question btw, you don't need to act on this if you think it's unnecessary.

PROPAGATE_FALSE(
arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, options.stop));
// and therefore we also need this check: end_sliced was computed independently
// of begin_sliced, so it may lie at or before it
if (end_sliced <= begin_sliced) {
// zero length slice
*output_written = 0;
return true;
}
} else if ((options.stop < 0) && (options.stop > options.start)) {
// stop is negative, but larger than start, so we count again from the right
// in some cases we can optimize this, depending on the shortest path (from end
// or begin_sliced), but begin_sliced and options.start can be 'out of sync',
// for instance when start=-100, when the string length is only 10.
PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(
begin_sliced, end, &end_sliced, -options.stop));
} else {
// zero length slice
*output_written = 0;
return true;
}
}
DCHECK(begin_sliced <= end_sliced);
if (options.step == 1) {
// fast case, where we simply can finish with a memcpy
std::copy(begin_sliced, end_sliced, output);
*output_written = end_sliced - begin_sliced;
} else {
uint8_t* dest = output;
const uint8_t* i = begin_sliced;

while (i < end_sliced) {
uint32_t codepoint = 0;
// write a single codepoint
PROPAGATE_FALSE(arrow::util::UTF8Decode(&i, &codepoint));
dest = arrow::util::UTF8Encode(dest, codepoint);
// and skip the remainder (step - 1 codepoints, or fewer if the slice ends first)
int64_t skips = options.step - 1;
while ((skips--) && (i < end_sliced)) {
PROPAGATE_FALSE(arrow::util::UTF8Decode(&i, &codepoint));
}
}
*output_written = dest - output;
}
return true;
} else {  // step < 0
// serious +1 -1 kung fu because now begin_sliced and end_sliced act like reverse
// iterators.

if (options.start >= 0) {
// +1 because begin_sliced acts as the end of a reverse iterator
PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced,
options.start + 1));
// and make it point at the last codeunit of the preceding codepoint
begin_sliced--;
} else {
// -1 because start=-1 means the last codepoint, which is 0 advances
PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &begin_sliced,
-options.start - 1));
// and make it point at the last codeunit of the preceding codepoint
begin_sliced--;
}
// similar to options.start
if (options.stop >= 0) {
PROPAGATE_FALSE(
arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, options.stop + 1));
end_sliced--;
} else {
PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &end_sliced,
-options.stop - 1));
end_sliced--;
}

uint8_t* dest = output;
const uint8_t* i = begin_sliced;

// Walk backwards from begin_sliced down to (but not including) end_sliced.
while (i > end_sliced) {
uint32_t codepoint = 0;
// write a single codepoint
PROPAGATE_FALSE(arrow::util::UTF8DecodeReverse(&i, &codepoint));
dest = arrow::util::UTF8Encode(dest, codepoint);
// and skip the remainder (-step - 1 codepoints, or fewer if the slice ends first)
int64_t skips = -options.step - 1;
while ((skips--) && (i > end_sliced)) {
PROPAGATE_FALSE(arrow::util::UTF8DecodeReverse(&i, &codepoint));
}
}
*output_written = dest - output;
return true;
}
}

#undef PROPAGATE_FALSE

/// \brief String transform that slices each UTF-8 input according to the
/// SliceOptions held by SliceBase.
template <typename Type>
struct SliceCodeunits : SliceBase<Type, SliceCodeunits<Type>> {
  using Base = SliceBase<Type, SliceCodeunits<Type>>;
  using offset_type = typename Base::offset_type;
  using Base::Base;

  /// Upper bound, in bytes, on the total output size for `ninputs` strings
  /// totalling `input_ncodeunits` bytes.
  ///
  /// The result never exceeds `input_ncodeunits`. Whenever the tighter estimate
  /// cannot be computed without overflowing int64_t (for instance with the
  /// default `stop` of INT64_MAX), the worst-case `input_ncodeunits` is
  /// returned instead of performing the overflowing arithmetic.
  int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override {
    const SliceOptions& opt = this->options;
    if ((opt.start >= 0) != (opt.stop >= 0)) {
      // If start and stop don't have the same sign, we can't guess an upper bound
      // on the resulting slice lengths, so return a worst case estimate.
      return input_ncodeunits;
    }
    constexpr int64_t kInt64Max = std::numeric_limits<int64_t>::max();
    constexpr int64_t kInt64Min = std::numeric_limits<int64_t>::min();
    if (opt.step == 0 || opt.step == kInt64Min) {
      // A zero step cannot be divided by, and `step - 1` would overflow for the
      // minimum value: fall back to the worst case estimate.
      return input_ncodeunits;
    }
    // start and stop have the same sign here, so their difference cannot overflow.
    const int64_t span = opt.stop - opt.start;
    const int64_t rounding = opt.step - 1;
    if ((rounding > 0 && span > kInt64Max - rounding) ||
        (rounding < 0 && span < kInt64Min - rounding)) {
      // `span + rounding` would overflow.
      return input_ncodeunits;
    }
    const int64_t numerator = span + rounding;
    if (numerator == kInt64Min && opt.step == -1) {
      // The division itself would overflow.
      return input_ncodeunits;
    }
    const int64_t max_slice_codepoints = std::max<int64_t>(0, numerator / opt.step);
    // The maximum UTF8 byte size of a codepoint is 4. If 4 * ninputs * codepoints
    // exceeds input_ncodeunits the minimum is input_ncodeunits anyway, so detect
    // that by division rather than by a multiplication that could overflow.
    if (max_slice_codepoints > 0 && ninputs > 0 &&
        max_slice_codepoints > input_ncodeunits / 4 / ninputs) {
      return input_ncodeunits;
    }
    return std::min(input_ncodeunits, 4 * (ninputs * max_slice_codepoints));
  }

  /// Slices a single string; see SliceCodeunitsTransform. Returns false when
  /// that helper fails, in which case `*output_written` is set to 0.
  bool Transform(const uint8_t* input, offset_type input_string_ncodeunits,
                 uint8_t* output, offset_type* output_written) {
    // Initialized because the helper leaves it untouched when it fails.
    int64_t output_written_64 = 0;
    const bool res = SliceCodeunitsTransform(input, input_string_ncodeunits, output,
                                             &output_written_64, this->options);
    *output_written = static_cast<offset_type>(output_written_64);
    return res;
  }
};

// User-facing documentation for the "utf8_slice_codeunits" function: summary,
// description, argument names and the name of the options class.
const FunctionDoc utf8_slice_codeunits_doc(
    "Slice string ",
    ("For each string in `strings`, slice into a substring defined by\n"
     "(`start`, `stop`, `step`) as given by `SliceOptions` where `start` is inclusive\n"
     "and `stop` is exclusive and are measured in codeunits. If step is negative, the\n"
     "string will be advanced in reversed order. A `step` of zero is considered an\n"
     "error.\n"
     "Null inputs emit null."),
    {"strings"}, "SliceOptions");

// Registers the "utf8_slice_codeunits" unary scalar function, with one kernel
// for utf8 input and one for large_utf8 input.
void AddSlice(FunctionRegistry* registry) {
  using Utf8Slice = SliceCodeunits<StringType>;
  using LargeUtf8Slice = SliceCodeunits<LargeStringType>;

  std::shared_ptr<ScalarFunction> slice_func = std::make_shared<ScalarFunction>(
      "utf8_slice_codeunits", Arity::Unary(), &utf8_slice_codeunits_doc);

  DCHECK_OK(
      slice_func->AddKernel({utf8()}, utf8(), Utf8Slice::Exec, Utf8Slice::State::Init));
  DCHECK_OK(slice_func->AddKernel({large_utf8()}, large_utf8(), LargeUtf8Slice::Exec,
                                  LargeUtf8Slice::State::Init));
  DCHECK_OK(registry->AddFunction(std::move(slice_func)));
}
// IsAlpha/Digit etc

#ifdef ARROW_WITH_UTF8PROC
Expand Down Expand Up @@ -2716,7 +2926,6 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
AddUnaryStringPredicate<IsUpperUnicode>("utf8_is_upper", registry, &utf8_is_upper_doc);
#endif

AddSplit(registry);
AddBinaryLength(registry);
AddUtf8Length(registry);
AddMatchSubstring(registry);
Expand All @@ -2730,6 +2939,8 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
MemAllocation::NO_PREALLOCATE);
AddExtractRegex(registry);
#endif
AddSlice(registry);
AddSplit(registry);
AddStrptime(registry);
}

Expand Down
Loading