ARROW-10557: [C++] Add scalar string slicing/substring extract kernel #9000
Changes from all commits: f8c03a4, 8ef4526, 7773b58
@@ -138,7 +138,10 @@ struct StringTransform {
   using offset_type = typename Type::offset_type;
   using ArrayType = typename TypeTraits<Type>::ArrayType;

-  static int64_t MaxCodeunits(offset_type input_ncodeunits) { return input_ncodeunits; }
+  virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) {
+    return input_ncodeunits;
+  }
+
   static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
     return Derived().Execute(ctx, batch, out);
   }

@@ -156,7 +159,8 @@ struct StringTransform {
       offset_type input_ncodeunits = input_boxed.total_values_length();
       offset_type input_nstrings = static_cast<offset_type>(input.length);

-      int64_t output_ncodeunits_max = Derived::MaxCodeunits(input_ncodeunits);
+      const int64_t output_ncodeunits_max =
+          MaxCodeunits(input_nstrings, input_ncodeunits);
       if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
         return Status::CapacityError(
             "Result might not fit in a 32bit utf8 array, convert to large_utf8");

@@ -183,35 +187,36 @@ struct StringTransform {
         output_ncodeunits += encoded_nbytes;
         output_string_offsets[i + 1] = output_ncodeunits;
       }
+      DCHECK_LE(output_ncodeunits, output_ncodeunits_max);

       // Trim the codepoint buffer, since we allocated too much
-      RETURN_NOT_OK(values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true));
+      return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true);
     } else {
       DCHECK_EQ(batch[0].kind(), Datum::SCALAR);
       const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar());
-      auto result = checked_pointer_cast<BaseBinaryScalar>(MakeNullScalar(out->type()));
-      if (input.is_valid) {
-        result->is_valid = true;
-        offset_type data_nbytes = static_cast<offset_type>(input.value->size());
-
-        int64_t output_ncodeunits_max = Derived::MaxCodeunits(data_nbytes);
-        if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
-          return Status::CapacityError(
-              "Result might not fit in a 32bit utf8 array, convert to large_utf8");
-        }
-        ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max));
-        result->value = value_buffer;
-        offset_type encoded_nbytes = 0;
-        if (ARROW_PREDICT_FALSE(!static_cast<Derived&>(*this).Transform(
-                input.value->data(), data_nbytes, value_buffer->mutable_data(),
-                &encoded_nbytes))) {
-          return Derived::InvalidStatus();
-        }
-        RETURN_NOT_OK(value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true));
-      }
-      out->value = result;
+      if (!input.is_valid) {
+        return Status::OK();
+      }
+      auto* result = checked_cast<BaseBinaryScalar*>(out->scalar().get());
+      result->is_valid = true;
+      offset_type data_nbytes = static_cast<offset_type>(input.value->size());
+
+      int64_t output_ncodeunits_max = MaxCodeunits(1, data_nbytes);
+      if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
+        return Status::CapacityError(
+            "Result might not fit in a 32bit utf8 array, convert to large_utf8");
+      }
+      ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max));
+      result->value = value_buffer;
+      offset_type encoded_nbytes = 0;
+      if (ARROW_PREDICT_FALSE(!static_cast<Derived&>(*this).Transform(
+              input.value->data(), data_nbytes, value_buffer->mutable_data(),
+              &encoded_nbytes))) {
+        return Derived::InvalidStatus();
+      }
+      DCHECK_LE(encoded_nbytes, output_ncodeunits_max);
+      return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true);
     }

     return Status::OK();
   }
 };

@@ -234,7 +239,8 @@ struct StringTransformCodepoint : StringTransform<Type, Derived> {
     *output_written = static_cast<offset_type>(output - output_start);
     return true;
   }
-  static int64_t MaxCodeunits(offset_type input_ncodeunits) {
+
+  int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override {
     // Section 5.18 of the Unicode spec claim that the number of codepoints for case
     // mapping can grow by a factor of 3. This means grow by a factor of 3 in bytes
     // However, since we don't support all casings (SpecialCasing.txt) the growth

@@ -243,6 +249,7 @@ struct StringTransformCodepoint : StringTransform<Type, Derived> {
     // two code units (even) can grow to 3 code units.
     return static_cast<int64_t>(input_ncodeunits) * 3 / 2;
   }
+
   Status Execute(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
     EnsureLookupTablesFilled();
     return Base::Execute(ctx, batch, out);

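The 3/2 growth bound above can be checked against a concrete codepoint pair. The snippet below is an editor's illustration, not part of the diff; the simple case mapping U+0250 → U+2C6F is a standard Unicode mapping in which a 2-codeunit codepoint grows to 3 codeunits, exactly the worst case the comment describes.

#include <cassert>
#include <string>

int main() {
  // "ɐ" (U+0250) upper-cases to "Ɐ" (U+2C6F): 2 UTF-8 codeunits grow to 3.
  const std::string lower = "\xC9\x90";      // U+0250, 2 codeunits
  const std::string upper = "\xE2\xB1\xAF";  // U+2C6F, 3 codeunits
  // The reserved budget of input_ncodeunits * 3 / 2 covers this growth exactly.
  assert(upper.size() <= lower.size() * 3 / 2);
  return 0;
}
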
@@ -758,6 +765,209 @@ void AddFindSubstring(FunctionRegistry* registry) {
   DCHECK_OK(registry->AddFunction(std::move(func)));
 }

+// Slicing
+
+template <typename Type, typename Derived>
+struct SliceBase : StringTransform<Type, Derived> {
+  using Base = StringTransform<Type, Derived>;
+  using offset_type = typename Base::offset_type;
+  using State = OptionsWrapper<SliceOptions>;
+
+  SliceOptions options;
+
+  explicit SliceBase(SliceOptions options) : options(options) {}
+
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    SliceOptions options = State::Get(ctx);
+    if (options.step == 0) {
+      return Status::Invalid("Slice step cannot be zero");
+    }
+    return Derived(options).Execute(ctx, batch, out);
+  }
+};
+
+#define PROPAGATE_FALSE(expr)         \
+  do {                                \
+    if (ARROW_PREDICT_FALSE(!expr)) { \
+      return false;                   \
+    }                                 \
+  } while (0)
+
+bool SliceCodeunitsTransform(const uint8_t* input, int64_t input_string_ncodeunits,
+                             uint8_t* output, int64_t* output_written,
+                             const SliceOptions& options) {
+  const uint8_t* begin = input;
+  const uint8_t* end = input + input_string_ncodeunits;
+  const uint8_t* begin_sliced = begin;
+  const uint8_t* end_sliced = end;
+
+  if (options.step >= 1) {
+    if (options.start >= 0) {
+      // start counting from the left
+      PROPAGATE_FALSE(
+          arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, options.start));
+      if (options.stop > options.start) {
+        // continue counting from begin_sliced
+        int64_t length = options.stop - options.start;
+        PROPAGATE_FALSE(
+            arrow::util::UTF8AdvanceCodepoints(begin_sliced, end, &end_sliced, length));
+      } else if (options.stop < 0) {
+        // or from the end (but we will never need to go before begin_sliced)
+        PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(
+            begin_sliced, end, &end_sliced, -options.stop));
+      } else {
+        // zero length slice
+        *output_written = 0;
+        return true;
+      }
+    } else {
+      // start counting from the right
+      PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &begin_sliced,
+                                                                -options.start));
+      if (options.stop > 0) {
+        // continue counting from the left, we cannot start from begin_sliced because we
+        // don't know how many codepoints are between begin and begin_sliced
+        PROPAGATE_FALSE(
+            arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, options.stop));
+        // and therefore we also need this check
+        if (end_sliced <= begin_sliced) {
+          // zero length slice
+          *output_written = 0;
+          return true;
+        }
+      } else if ((options.stop < 0) && (options.stop > options.start)) {
+        // stop is negative, but larger than start, so we count again from the right
+        // in some cases we can optimize this, depending on the shortest path (from end
+        // or begin_sliced), but begin_sliced and options.start can be 'out of sync',
+        // for instance when start=-100, when the string length is only 10.
+        PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(
+            begin_sliced, end, &end_sliced, -options.stop));
+      } else {
+        // zero length slice
+        *output_written = 0;
+        return true;
+      }
+    }
+    DCHECK(begin_sliced <= end_sliced);
+    if (options.step == 1) {
+      // fast case, where we simply can finish with a memcpy
+      std::copy(begin_sliced, end_sliced, output);
+      *output_written = end_sliced - begin_sliced;
+    } else {
+      uint8_t* dest = output;
+      const uint8_t* i = begin_sliced;
+
+      while (i < end_sliced) {
+        uint32_t codepoint = 0;
+        // write a single codepoint
+        PROPAGATE_FALSE(arrow::util::UTF8Decode(&i, &codepoint));
+        dest = arrow::util::UTF8Encode(dest, codepoint);
+        // and skip the remainder
+        int64_t skips = options.step - 1;
+        while ((skips--) && (i < end_sliced)) {
+          PROPAGATE_FALSE(arrow::util::UTF8Decode(&i, &codepoint));
+        }
+      }
+      *output_written = dest - output;
+    }
+    return true;
+  } else {  // step < 0
+    // serious +1 -1 kung fu because now begin_sliced and end_sliced act like reverse
+    // iterators.
+
+    if (options.start >= 0) {
+      // +1 because begin_sliced acts as the end of a reverse iterator
+      PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced,
+                                                         options.start + 1));
+      // and make it point at the last codeunit of the previous codepoint
+      begin_sliced--;
+    } else {
+      // -1 because start=-1 means the last codeunit, which is 0 advances from the end
+      PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &begin_sliced,
+                                                                -options.start - 1));
+      // and make it point at the last codeunit of the previous codepoint
+      begin_sliced--;
+    }
+    // similar to options.start
+    if (options.stop >= 0) {
+      PROPAGATE_FALSE(
+          arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, options.stop + 1));
+      end_sliced--;
+    } else {
+      PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &end_sliced,
+                                                                -options.stop - 1));
+      end_sliced--;
+    }
+
+    uint8_t* dest = output;
+    const uint8_t* i = begin_sliced;
+
+    while (i > end_sliced) {
+      uint32_t codepoint = 0;
+      // write a single codepoint
+      PROPAGATE_FALSE(arrow::util::UTF8DecodeReverse(&i, &codepoint));
+      dest = arrow::util::UTF8Encode(dest, codepoint);
+      // and skip the remainder
+      int64_t skips = -options.step - 1;
+      while ((skips--) && (i > end_sliced)) {
+        PROPAGATE_FALSE(arrow::util::UTF8DecodeReverse(&i, &codepoint));
+      }
+    }
+    *output_written = dest - output;
+    return true;
+  }
+}
+
+#undef PROPAGATE_FALSE
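Since the branchy logic above is easy to misread, here is a small self-contained sketch (an editor's illustration, not part of the diff) of the semantics the transform targets: Python-style `s[start:stop:step]` slicing. It works on ASCII `std::string` input, where codeunits and codepoints coincide, so the expected kernel output can be verified by hand; `SliceAscii` is a hypothetical helper, not Arrow code.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <string>

std::string SliceAscii(const std::string& s, int64_t start, int64_t stop, int64_t step) {
  const int64_t n = static_cast<int64_t>(s.size());
  auto clamp = [n](int64_t i, int64_t lo, int64_t hi) {
    if (i < 0) i += n;  // negative indices count from the end
    return std::max(lo, std::min(i, hi));
  };
  std::string out;
  if (step > 0) {
    for (int64_t i = clamp(start, 0, n); i < clamp(stop, 0, n); i += step) {
      out += s[static_cast<size_t>(i)];
    }
  } else if (step < 0) {  // step == 0 is rejected by SliceBase::Exec
    for (int64_t i = clamp(start, -1, n - 1); i > clamp(stop, -1, n - 1); i += step) {
      out += s[static_cast<size_t>(i)];
    }
  }
  return out;
}

int main() {
  assert(SliceAscii("foobar", 1, 4, 1) == "oob");   // plain substring
  assert(SliceAscii("foobar", 0, 6, 2) == "foa");   // every second codeunit
  assert(SliceAscii("foobar", -3, 6, 1) == "bar");  // negative start counts from the end
  assert(SliceAscii("foobar", 5, 0, -2) == "rbo");  // negative step walks backwards
  return 0;
}
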
+template <typename Type>
+struct SliceCodeunits : SliceBase<Type, SliceCodeunits<Type>> {
+  using Base = SliceBase<Type, SliceCodeunits<Type>>;
+  using offset_type = typename Base::offset_type;
+  using Base::Base;
+
+  int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override {
+    const SliceOptions& opt = this->options;
+    if ((opt.start >= 0) != (opt.stop >= 0)) {
+      // If start and stop don't have the same sign, we can't guess an upper bound
+      // on the resulting slice lengths, so return a worst case estimate.
+      return input_ncodeunits;
+    }
+    int64_t max_slice_codepoints = (opt.stop - opt.start + opt.step - 1) / opt.step;
+    // The maximum UTF8 byte size of a codepoint is 4
+    return std::min(input_ncodeunits,
+                    4 * ninputs * std::max<int64_t>(0, max_slice_codepoints));
+  }
+
+  bool Transform(const uint8_t* input, offset_type input_string_ncodeunits,
+                 uint8_t* output, offset_type* output_written) {
+    int64_t output_written_64;
+    bool res = SliceCodeunitsTransform(input, input_string_ncodeunits, output,
+                                       &output_written_64, this->options);
+    *output_written = static_cast<offset_type>(output_written_64);
+    return res;
+  }
+};
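The allocation bound in `MaxCodeunits` is easiest to sanity-check with concrete numbers. The sketch below re-evaluates the same arithmetic outside the kernel, assuming same-sign `start`/`stop` (the mixed-sign branch above simply falls back to `input_ncodeunits`); `SliceOutputBound` is an illustrative name, not Arrow API.

#include <algorithm>
#include <cstdint>
#include <iostream>

// Same arithmetic as SliceCodeunits::MaxCodeunits, for same-sign start/stop.
int64_t SliceOutputBound(int64_t ninputs, int64_t input_ncodeunits, int64_t start,
                         int64_t stop, int64_t step) {
  // At most ceil((stop - start) / step) codepoints survive the slice per string.
  const int64_t max_slice_codepoints = (stop - start + step - 1) / step;
  // A UTF-8 codepoint occupies at most 4 codeunits (bytes).
  return std::min(input_ncodeunits,
                  4 * ninputs * std::max<int64_t>(0, max_slice_codepoints));
}

int main() {
  // 1000 strings totalling 50000 codeunits, sliced with start=2, stop=7, step=2:
  // ceil(5 / 2) = 3 codepoints per string -> 4 * 1000 * 3 = 12000 codeunits,
  // below the 50000-codeunit worst case, so 12000 bytes would be allocated.
  std::cout << SliceOutputBound(1000, 50000, 2, 7, 2) << std::endl;  // prints 12000
  return 0;
}
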
+const FunctionDoc utf8_slice_codeunits_doc(
+    "Slice string",
+    ("For each string in `strings`, slice into a substring defined by\n"
+     "(`start`, `stop`, `step`) as given by `SliceOptions`, where `start` is inclusive\n"
+     "and `stop` is exclusive, both measured in codeunits. If `step` is negative, the\n"
+     "string is advanced in reverse order. A `step` of zero is considered an\n"
+     "error.\n"
+     "Null inputs emit null."),
+    {"strings"}, "SliceOptions");
+void AddSlice(FunctionRegistry* registry) {
+  auto func = std::make_shared<ScalarFunction>("utf8_slice_codeunits", Arity::Unary(),
+                                               &utf8_slice_codeunits_doc);
+  using t32 = SliceCodeunits<StringType>;
+  using t64 = SliceCodeunits<LargeStringType>;
+  DCHECK_OK(func->AddKernel({utf8()}, utf8(), t32::Exec, t32::State::Init));
+  DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), t64::Exec, t64::State::Init));
+  DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
 // IsAlpha/Digit etc

 #ifdef ARROW_WITH_UTF8PROC

@@ -2716,7 +2926,6 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
   AddUnaryStringPredicate<IsUpperUnicode>("utf8_is_upper", registry, &utf8_is_upper_doc);
 #endif

-  AddSplit(registry);
   AddBinaryLength(registry);
   AddUtf8Length(registry);
   AddMatchSubstring(registry);

@@ -2730,6 +2939,8 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
                     MemAllocation::NO_PREALLOCATE);
   AddExtractRegex(registry);
 #endif
+  AddSlice(registry);
+  AddSplit(registry);
   AddStrptime(registry);
 }