diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index fc9dc35490e..62d52d245fb 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -49,6 +49,25 @@ struct ARROW_EXPORT MatchSubstringOptions : public FunctionOptions { std::string pattern; }; +struct ARROW_EXPORT SplitOptions : public FunctionOptions { + explicit SplitOptions(int64_t max_splits = -1, bool reverse = false) + : max_splits(max_splits), reverse(reverse) {} + + /// Maximum number of splits allowed, or unlimited when -1 + int64_t max_splits; + /// Start splitting from the end of the string (only relevant when max_splits != -1) + bool reverse; +}; + +struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { + explicit SplitPatternOptions(std::string pattern, int64_t max_splits = -1, + bool reverse = false) + : SplitOptions(max_splits, reverse), pattern(std::move(pattern)) {} + + /// The exact substring to look for inside input values. + std::string pattern; +}; + /// Options for IsIn and IndexIn functions struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { explicit SetLookupOptions(Datum value_set, bool skip_nulls) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index f92343d9ed7..00ab80ba23d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -23,6 +23,9 @@ #include #endif +#include "arrow/array/builder_binary.h" +#include "arrow/array/builder_nested.h" +#include "arrow/buffer_builder.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common.h" #include "arrow/util/utf8.h" @@ -815,6 +818,394 @@ struct IsUpperAscii : CharacterPredicateAscii { } }; +// splitting + +template +struct SplitBaseTransform { + using string_offset_type = typename Type::offset_type; + using list_offset_type = typename ListType::offset_type; + using ArrayType = typename TypeTraits::ArrayType; + using ArrayListType = typename TypeTraits::ArrayType; + using ListScalarType = typename TypeTraits::ScalarType; + using ScalarType = typename TypeTraits::ScalarType; + using BuilderType = typename TypeTraits::BuilderType; + using ListOffsetsBuilderType = TypedBufferBuilder; + using State = OptionsWrapper; + + std::vector parts; + Options options; + + explicit SplitBaseTransform(Options options) : options(options) {} + + Status Split(const util::string_view& s, BuilderType* builder) { + const uint8_t* begin = reinterpret_cast(s.data()); + const uint8_t* end = begin + s.length(); + + int64_t max_splits = options.max_splits; + // if there is no max splits, reversing does not make sense (and is probably less + // efficient), but is useful for testing + if (options.reverse) { + // note that i points 1 further than the 'current' + const uint8_t* i = end; + // we will record the parts in reverse order + parts.clear(); + if (max_splits > -1) { + parts.reserve(max_splits + 1); + } + while (max_splits != 0) { + const uint8_t *separator_begin, *separator_end; + // find with whatever algo the part we will 'cut out' + if (static_cast(*this).FindReverse(begin, i, &separator_begin, + &separator_end, options)) { + parts.emplace_back(reinterpret_cast(separator_end), + i - separator_end); + i = separator_begin; + max_splits--; + } else { + // if we cannot find a separator, we're done + break; + } + } + parts.emplace_back(reinterpret_cast(begin), i - begin); + // now we do the copying + for (auto it = parts.rbegin(); it != parts.rend(); ++it) { + RETURN_NOT_OK(builder->Append(*it)); + } + } else { + const uint8_t* i = begin; + while (max_splits != 0) { + const uint8_t *separator_begin, *separator_end; + // find with whatever algo the part we will 'cut out' + if (static_cast(*this).Find(i, end, &separator_begin, &separator_end, + options)) { + // the part till the beginning of the 'cut' + RETURN_NOT_OK( + builder->Append(i, static_cast(separator_begin - i))); + i = separator_end; + max_splits--; + } else { + // if we cannot find a separator, we're done + break; + } + } + // trailing part + RETURN_NOT_OK(builder->Append(i, static_cast(end - i))); + } + return Status::OK(); + } + + static Status CheckOptions(const Options& options) { return Status::OK(); } + + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + Options options = State::Get(ctx); + Derived splitter(options); // we make an instance to reuse the parts vectors + splitter.Split(ctx, batch, out); + } + + void Split(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + EnsureLookupTablesFilled(); // only needed for unicode + KERNEL_RETURN_IF_ERROR(ctx, Derived::CheckOptions(options)); + + if (batch[0].kind() == Datum::ARRAY) { + const ArrayData& input = *batch[0].array(); + ArrayType input_boxed(batch[0].array()); + + string_offset_type input_nstrings = static_cast(input.length); + + BuilderType builder(input.type, ctx->memory_pool()); + // a slight overestimate of the data needed + KERNEL_RETURN_IF_ERROR(ctx, builder.ReserveData(input_boxed.total_values_length())); + // the minimum amount of strings needed + KERNEL_RETURN_IF_ERROR(ctx, builder.Resize(input.length)); + + // ideally we do not allocate this, see + // https://issues.apache.org/jira/browse/ARROW-10207 + ListOffsetsBuilderType list_offsets_builder(ctx->memory_pool()); + KERNEL_RETURN_IF_ERROR(ctx, list_offsets_builder.Resize(input_nstrings)); + ArrayData* output_list = out->mutable_array(); + // // we use the same null values + output_list->buffers[0] = input.buffers[0]; + // initial value + KERNEL_RETURN_IF_ERROR( + ctx, list_offsets_builder.Append(static_cast(0))); + KERNEL_RETURN_IF_ERROR( + ctx, + VisitArrayDataInline( + input, + [&](util::string_view s) { + RETURN_NOT_OK(Split(s, &builder)); + if (ARROW_PREDICT_FALSE(builder.length() > + std::numeric_limits::max())) { + return Status::CapacityError("List offset does not fit into 32 bit"); + } + RETURN_NOT_OK(list_offsets_builder.Append( + static_cast(builder.length()))); + return Status::OK(); + }, + [&]() { + // null value is already taken from input + RETURN_NOT_OK(list_offsets_builder.Append( + static_cast(builder.length()))); + return Status::OK(); + })); + // assign list indices + KERNEL_RETURN_IF_ERROR(ctx, list_offsets_builder.Finish(&output_list->buffers[1])); + // assign list child data + std::shared_ptr string_array; + KERNEL_RETURN_IF_ERROR(ctx, builder.Finish(&string_array)); + output_list->child_data.push_back(string_array->data()); + + } else { + const auto& input = checked_cast(*batch[0].scalar()); + auto result = checked_pointer_cast(MakeNullScalar(out->type())); + if (input.is_valid) { + result->is_valid = true; + BuilderType builder(input.type, ctx->memory_pool()); + util::string_view s = static_cast(*input.value); + KERNEL_RETURN_IF_ERROR(ctx, Split(s, &builder)); + KERNEL_RETURN_IF_ERROR(ctx, builder.Finish(&result->value)); + } + out->value = result; + } + } +}; + +template +struct SplitPatternTransform : SplitBaseTransform> { + using Base = SplitBaseTransform>; + using ArrayType = typename TypeTraits::ArrayType; + using ScalarType = typename TypeTraits::ScalarType; + using string_offset_type = typename Type::offset_type; + using Base::Base; + + static Status CheckOptions(const SplitPatternOptions& options) { + if (options.pattern.length() == 0) { + return Status::Invalid("Empty separator"); + } + return Status::OK(); + } + static bool Find(const uint8_t* begin, const uint8_t* end, + const uint8_t** separator_begin, const uint8_t** separator_end, + const SplitPatternOptions& options) { + const uint8_t* pattern = reinterpret_cast(options.pattern.c_str()); + const int64_t pattern_length = options.pattern.length(); + const uint8_t* i = begin; + // this is O(n*m) complexity, we could use the Knuth-Morris-Pratt algorithm used in + // the match kernel + while ((i + pattern_length <= end)) { + i = std::search(i, end, pattern, pattern + pattern_length); + if (i != end) { + *separator_begin = i; + *separator_end = i + pattern_length; + return true; + } + } + return false; + } + static bool FindReverse(const uint8_t* begin, const uint8_t* end, + const uint8_t** separator_begin, const uint8_t** separator_end, + const SplitPatternOptions& options) { + const uint8_t* pattern = reinterpret_cast(options.pattern.c_str()); + const int64_t pattern_length = options.pattern.length(); + // this is O(n*m) complexity, we could use the Knuth-Morris-Pratt algorithm used in + // the match kernel + std::reverse_iterator ri(end); + std::reverse_iterator rend(begin); + std::reverse_iterator pattern_rbegin(pattern + pattern_length); + std::reverse_iterator pattern_rend(pattern); + while (begin <= ri.base() - pattern_length) { + ri = std::search(ri, rend, pattern_rbegin, pattern_rend); + if (ri != rend) { + *separator_begin = ri.base() - pattern_length; + *separator_end = ri.base(); + return true; + } + } + return false; + } +}; + +const FunctionDoc split_pattern_doc( + "Split string according to separator", + ("Split each string according to the exact `pattern` defined in\n" + "SplitPatternOptions. The output for each string input is a list\n" + "of strings.\n" + "\n" + "The maximum number of splits and direction of splitting\n" + "(forward, reverse) can optionally be defined in SplitPatternOptions."), + {"strings"}, "SplitPatternOptions"); + +const FunctionDoc ascii_split_whitespace_doc( + "Split string according to any ASCII whitespace", + ("Split each string according any non-zero length sequence of ASCII\n" + "whitespace characters. The output for each string input is a list\n" + "of strings.\n" + "\n" + "The maximum number of splits and direction of splitting\n" + "(forward, reverse) can optionally be defined in SplitOptions."), + {"strings"}, "SplitOptions"); + +const FunctionDoc utf8_split_whitespace_doc( + "Split string according to any Unicode whitespace", + ("Split each string according any non-zero length sequence of Unicode\n" + "whitespace characters. The output for each string input is a list\n" + "of strings.\n" + "\n" + "The maximum number of splits and direction of splitting\n" + "(forward, reverse) can optionally be defined in SplitOptions."), + {"strings"}, "SplitOptions"); + +void AddSplitPattern(FunctionRegistry* registry) { + auto func = std::make_shared("split_pattern", Arity::Unary(), + &split_pattern_doc); + using t32 = SplitPatternTransform; + using t64 = SplitPatternTransform; + DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init)); + DCHECK_OK( + func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + +template +struct SplitWhitespaceAsciiTransform + : SplitBaseTransform> { + using Base = SplitBaseTransform>; + using ArrayType = typename TypeTraits::ArrayType; + using ScalarType = typename TypeTraits::ScalarType; + using string_offset_type = typename Type::offset_type; + using Base::Base; + static bool Find(const uint8_t* begin, const uint8_t* end, + const uint8_t** separator_begin, const uint8_t** separator_end, + const SplitOptions& options) { + const uint8_t* i = begin; + while ((i < end)) { + if (IsSpaceCharacterAscii(*i)) { + *separator_begin = i; + do { + i++; + } while (IsSpaceCharacterAscii(*i) && i < end); + *separator_end = i; + return true; + } + i++; + } + return false; + } + static bool FindReverse(const uint8_t* begin, const uint8_t* end, + const uint8_t** separator_begin, const uint8_t** separator_end, + const SplitOptions& options) { + const uint8_t* i = end - 1; + while ((i >= begin)) { + if (IsSpaceCharacterAscii(*i)) { + *separator_end = i + 1; + do { + i--; + } while (IsSpaceCharacterAscii(*i) && i >= begin); + *separator_begin = i + 1; + return true; + } + i--; + } + return false; + } +}; + +void AddSplitWhitespaceAscii(FunctionRegistry* registry) { + static const SplitOptions default_options{}; + auto func = + std::make_shared("ascii_split_whitespace", Arity::Unary(), + &ascii_split_whitespace_doc, &default_options); + using t32 = SplitWhitespaceAsciiTransform; + using t64 = SplitWhitespaceAsciiTransform; + DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init)); + DCHECK_OK( + func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + +#ifdef ARROW_WITH_UTF8PROC +template +struct SplitWhitespaceUtf8Transform + : SplitBaseTransform> { + using Base = SplitBaseTransform>; + using ArrayType = typename TypeTraits::ArrayType; + using string_offset_type = typename Type::offset_type; + using ScalarType = typename TypeTraits::ScalarType; + using Base::Base; + static bool Find(const uint8_t* begin, const uint8_t* end, + const uint8_t** separator_begin, const uint8_t** separator_end, + const SplitOptions& options) { + const uint8_t* i = begin; + while ((i < end)) { + uint32_t codepoint = 0; + *separator_begin = i; + if (ARROW_PREDICT_FALSE(!arrow::util::UTF8Decode(&i, &codepoint))) { + return false; + } + if (IsSpaceCharacterUnicode(codepoint)) { + do { + *separator_end = i; + if (ARROW_PREDICT_FALSE(!arrow::util::UTF8Decode(&i, &codepoint))) { + return false; + } + } while (IsSpaceCharacterUnicode(codepoint) && i < end); + return true; + } + } + return false; + } + static bool FindReverse(const uint8_t* begin, const uint8_t* end, + const uint8_t** separator_begin, const uint8_t** separator_end, + const SplitOptions& options) { + const uint8_t* i = end - 1; + while ((i >= begin)) { + uint32_t codepoint = 0; + *separator_end = i + 1; + if (ARROW_PREDICT_FALSE(!arrow::util::UTF8DecodeReverse(&i, &codepoint))) { + return false; + } + if (IsSpaceCharacterUnicode(codepoint)) { + do { + *separator_begin = i + 1; + if (ARROW_PREDICT_FALSE(!arrow::util::UTF8DecodeReverse(&i, &codepoint))) { + return false; + } + } while (IsSpaceCharacterUnicode(codepoint) && i >= begin); + return true; + } + } + return false; + } +}; + +void AddSplitWhitespaceUTF8(FunctionRegistry* registry) { + static const SplitOptions default_options{}; + auto func = + std::make_shared("utf8_split_whitespace", Arity::Unary(), + &utf8_split_whitespace_doc, &default_options); + using t32 = SplitWhitespaceUtf8Transform; + using t64 = SplitWhitespaceUtf8Transform; + DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init)); + DCHECK_OK( + func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} +#endif + +void AddSplit(FunctionRegistry* registry) { + AddSplitPattern(registry); + AddSplitWhitespaceAscii(registry); +#ifdef ARROW_WITH_UTF8PROC + AddSplitWhitespaceUTF8(registry); +#endif +} + // ---------------------------------------------------------------------- // strptime string parsing @@ -1103,6 +1494,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddUnaryStringPredicate("utf8_is_upper", registry, &utf8_is_upper_doc); #endif + AddSplit(registry); AddBinaryLength(registry); AddMatchSubstring(registry); AddStrptime(registry); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc index 0ad7f724c5f..4b77cf07bcf 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc @@ -66,6 +66,11 @@ static void MatchSubstring(benchmark::State& state) { UnaryStringBenchmark(state, "match_substring", &options); } +static void SplitPattern(benchmark::State& state) { + SplitPatternOptions options("a"); + UnaryStringBenchmark(state, "split_pattern", &options); +} + #ifdef ARROW_WITH_UTF8PROC static void Utf8Upper(benchmark::State& state) { UnaryStringBenchmark(state, "utf8_upper"); @@ -84,6 +89,7 @@ BENCHMARK(AsciiLower); BENCHMARK(AsciiUpper); BENCHMARK(IsAlphaNumericAscii); BENCHMARK(MatchSubstring); +BENCHMARK(SplitPattern); #ifdef ARROW_WITH_UTF8PROC BENCHMARK(Utf8Lower); BENCHMARK(Utf8Upper); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index a96716ad39c..e77a4cc5765 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -334,6 +334,90 @@ TYPED_TEST(TestStringKernels, MatchSubstring) { &options_double_char_2); } +TYPED_TEST(TestStringKernels, SplitBasics) { + SplitPatternOptions options{" "}; + // basics + this->CheckUnary("split_pattern", R"(["foo bar", "foo"])", list(this->type()), + R"([["foo", "bar"], ["foo"]])", &options); + // TODO: enable test when the following issue is fixed: + // https://issues.apache.org/jira/browse/ARROW-10208 + // this->CheckUnary("split_pattern", R"(["foo bar", "foo", null])", list(this->type()), + // R"([["foo", "bar"], ["foo"], null])", &options); + // edgy cases + this->CheckUnary("split_pattern", R"(["f o o "])", list(this->type()), + R"([["f", "", "o", "o", ""]])", &options); + this->CheckUnary("split_pattern", "[]", list(this->type()), "[]", &options); + // longer patterns + SplitPatternOptions options_long{"---"}; + this->CheckUnary("split_pattern", R"(["-foo---bar--", "---foo---b"])", + list(this->type()), R"([["-foo", "bar--"], ["", "foo", "b"]])", + &options_long); + SplitPatternOptions options_long_reverse{"---", -1, /*reverse=*/true}; + this->CheckUnary("split_pattern", R"(["-foo---bar--", "---foo---b"])", + list(this->type()), R"([["-foo", "bar--"], ["", "foo", "b"]])", + &options_long_reverse); +} + +TYPED_TEST(TestStringKernels, SplitMax) { + SplitPatternOptions options{"---", 2}; + SplitPatternOptions options_reverse{"---", 2, /*reverse=*/true}; + this->CheckUnary("split_pattern", R"(["foo---bar", "foo", "foo---bar------ar"])", + list(this->type()), + R"([["foo", "bar"], ["foo"], ["foo", "bar", "---ar"]])", &options); + this->CheckUnary( + "split_pattern", R"(["foo---bar", "foo", "foo---bar------ar"])", list(this->type()), + R"([["foo", "bar"], ["foo"], ["foo---bar", "", "ar"]])", &options_reverse); +} + +TYPED_TEST(TestStringKernels, SplitWhitespaceAscii) { + SplitOptions options; + SplitOptions options_max{1}; + // basics + this->CheckUnary("ascii_split_whitespace", R"(["foo bar", "foo bar \tba"])", + list(this->type()), R"([["foo", "bar"], ["foo", "bar", "ba"]])", + &options); + this->CheckUnary("ascii_split_whitespace", R"(["foo bar", "foo bar \tba"])", + list(this->type()), R"([["foo", "bar"], ["foo", "bar \tba"]])", + &options_max); +} + +TYPED_TEST(TestStringKernels, SplitWhitespaceAsciiReverse) { + SplitOptions options{-1, /*reverse=*/true}; + SplitOptions options_max{1, /*reverse=*/true}; + // basics + this->CheckUnary("ascii_split_whitespace", R"(["foo bar", "foo bar \tba"])", + list(this->type()), R"([["foo", "bar"], ["foo", "bar", "ba"]])", + &options); + this->CheckUnary("ascii_split_whitespace", R"(["foo bar", "foo bar \tba"])", + list(this->type()), R"([["foo", "bar"], ["foo bar", "ba"]])", + &options_max); +} + +TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8) { + SplitOptions options; + SplitOptions options_max{1}; + // \xe2\x80\x88 is punctuation space + this->CheckUnary("utf8_split_whitespace", + "[\"foo bar\", \"foo\xe2\x80\x88 bar \\tba\"]", list(this->type()), + R"([["foo", "bar"], ["foo", "bar", "ba"]])", &options); + this->CheckUnary("utf8_split_whitespace", + "[\"foo bar\", \"foo\xe2\x80\x88 bar \\tba\"]", list(this->type()), + R"([["foo", "bar"], ["foo", "bar \tba"]])", &options_max); +} + +TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8Reverse) { + SplitOptions options{-1, /*reverse=*/true}; + SplitOptions options_max{1, /*reverse=*/true}; + // \xe2\x80\x88 is punctuation space + this->CheckUnary("utf8_split_whitespace", + "[\"foo bar\", \"foo\xe2\x80\x88 bar \\tba\"]", list(this->type()), + R"([["foo", "bar"], ["foo", "bar", "ba"]])", &options); + this->CheckUnary("utf8_split_whitespace", + "[\"foo bar\", \"foo\xe2\x80\x88 bar \\tba\"]", list(this->type()), + "[[\"foo\", \"bar\"], [\"foo\xe2\x80\x88 bar\", \"ba\"]]", + &options_max); +} + TYPED_TEST(TestStringKernels, Strptime) { std::string input1 = R"(["5/1/2020", null, "12/11/1900"])"; std::string output1 = R"(["2020-05-01", null, "1900-12-11"])"; diff --git a/cpp/src/arrow/util/utf8.h b/cpp/src/arrow/util/utf8.h index c089fa7fff6..afc6c172e48 100644 --- a/cpp/src/arrow/util/utf8.h +++ b/cpp/src/arrow/util/utf8.h @@ -321,6 +321,18 @@ static inline bool Utf8IsContinuation(const uint8_t codeunit) { return (codeunit & 0xC0) == 0x80; // upper two bits should be 10 } +static inline bool Utf8Is2ByteStart(const uint8_t codeunit) { + return (codeunit & 0xE0) == 0xC0; // upper three bits should be 110 +} + +static inline bool Utf8Is3ByteStart(const uint8_t codeunit) { + return (codeunit & 0xF0) == 0xE0; // upper four bits should be 1110 +} + +static inline bool Utf8Is4ByteStart(const uint8_t codeunit) { + return (codeunit & 0xF8) == 0xF0; // upper five bits should be 11110 +} + static inline uint8_t* UTF8Encode(uint8_t* str, uint32_t codepoint) { if (codepoint < 0x80) { *str++ = codepoint; @@ -389,6 +401,45 @@ static inline bool UTF8Decode(const uint8_t** data, uint32_t* codepoint) { return true; } +static inline bool UTF8DecodeReverse(const uint8_t** data, uint32_t* codepoint) { + const uint8_t* str = *data; + if (*str < 0x80) { // ascci + *codepoint = *str--; + } else { + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_N = (*str--) & 0x3F; // take last 6 bits + if (Utf8Is2ByteStart(*str)) { + uint8_t code_unit_1 = (*str--) & 0x1F; // take last 5 bits + *codepoint = (code_unit_1 << 6) + code_unit_N; + } else { + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_Nmin1 = (*str--) & 0x3F; // take last 6 bits + if (Utf8Is3ByteStart(*str)) { + uint8_t code_unit_1 = (*str--) & 0x0F; // take last 4 bits + *codepoint = (code_unit_1 << 12) + (code_unit_Nmin1 << 6) + code_unit_N; + } else { + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_Nmin2 = (*str--) & 0x3F; // take last 6 bits + if (ARROW_PREDICT_TRUE(Utf8Is4ByteStart(*str))) { + uint8_t code_unit_1 = (*str--) & 0x07; // take last 3 bits + *codepoint = (code_unit_1 << 18) + (code_unit_Nmin2 << 12) + + (code_unit_Nmin1 << 6) + code_unit_N; + } else { + return false; + } + } + } + } + *data = str; + return true; +} + template static inline bool UTF8Transform(const uint8_t* first, const uint8_t* last, uint8_t** destination, UnaryOperation&& unary_op) { diff --git a/cpp/src/arrow/util/utf8_util_test.cc b/cpp/src/arrow/util/utf8_util_test.cc index 167b402933d..44caf365089 100644 --- a/cpp/src/arrow/util/utf8_util_test.cc +++ b/cpp/src/arrow/util/utf8_util_test.cc @@ -374,5 +374,43 @@ TEST(WideStringToUTF8, Basics) { #endif } +TEST(UTF8DecodeReverse, Basics) { + auto CheckOk = [](const std::string& s) -> void { + const uint8_t* begin = reinterpret_cast(s.c_str()); + const uint8_t* end = begin + s.length(); + const uint8_t* i = end - 1; + uint32_t codepoint; + EXPECT_TRUE(UTF8DecodeReverse(&i, &codepoint)); + EXPECT_EQ(i, begin - 1); + }; + + // 0x80 == 0b10000000 + // 0xC0 == 0b11000000 + // 0xE0 == 0b11100000 + // 0xF0 == 0b11110000 + CheckOk("a"); + CheckOk("\xC0\x80"); + CheckOk("\xE0\x80\x80"); + CheckOk("\xF0\x80\x80\x80"); + + auto CheckInvalid = [](const std::string& s) -> void { + const uint8_t* begin = reinterpret_cast(s.c_str()); + const uint8_t* end = begin + s.length(); + const uint8_t* i = end - 1; + uint32_t codepoint; + EXPECT_FALSE(UTF8DecodeReverse(&i, &codepoint)); + }; + + // too many continuation code units + CheckInvalid("a\x80"); + CheckInvalid("\xC0\x80\x80"); + CheckInvalid("\xE0\x80\x80\x80"); + CheckInvalid("\xF0\x80\x80\x80\x80"); + // not enough continuation code units + CheckInvalid("\xC0"); + CheckInvalid("\xE0\x80"); + CheckInvalid("\xF0\x80\x80"); +} + } // namespace util } // namespace arrow diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index af2f485058c..c2b901ff8a0 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -395,6 +395,37 @@ Containment tests * \(3) Output is true iff the corresponding input element is equal to one of the elements in :member:`SetLookupOptions::value_set`. + +String splitting +~~~~~~~~~~~~~~~~ + +These functions split strings into lists of strings. All kernels can optionally +be configured with a ``max_splits`` and a ``reverse`` parameter, where +``max_splits == -1`` means no limit (the default). When ``reverse`` is true, +the splitting is done starting from the end of the string; this is only relevant +when a positive ``max_splits`` is given. + ++--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+ +| Function name | Arity | Input types | Output type | Options class | Notes | ++==========================+============+=========================+===================+==================================+=========+ +| split_pattern | Unary | String-like | List-like | :struct:`SplitPatternOptions` | \(1) | ++--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+ +| utf8_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(2) | ++--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+ +| ascii_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(3) | ++--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+ + +* \(1) The string is split when an exact pattern is found (the pattern itself + is not included in the output). + +* \(2) A non-zero length sequence of Unicode defined whitespace codepoints + is seen as separator. + +* \(3) A non-zero length sequence of ASCII defined whitespace bytes + (``'\t'``, ``'\n'``, ``'\v'``, ``'\f'``, ``'\r'`` and ``' '``) is seen + as separator. + + Structural transforms ~~~~~~~~~~~~~~~~~~~~~ diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 8aeceecb90d..af9eb473cf5 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -785,3 +785,37 @@ cdef class _VarianceOptions(FunctionOptions): class VarianceOptions(_VarianceOptions): def __init__(self, *, ddof=0): self._set_options(ddof) + + +cdef class _SplitOptions(FunctionOptions): + cdef: + unique_ptr[CSplitOptions] split_options + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.split_options.get() + + def _set_options(self, max_splits, reverse): + self.split_options.reset( + new CSplitOptions(max_splits, reverse)) + + +class SplitOptions(_SplitOptions): + def __init__(self, *, max_splits=-1, reverse=False): + self._set_options(max_splits, reverse) + + +cdef class _SplitPatternOptions(FunctionOptions): + cdef: + unique_ptr[CSplitPatternOptions] split_pattern_options + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.split_pattern_options.get() + + def _set_options(self, pattern, max_splits, reverse): + self.split_pattern_options.reset( + new CSplitPatternOptions(tobytes(pattern), max_splits, reverse)) + + +class SplitPatternOptions(_SplitPatternOptions): + def __init__(self, *, pattern, max_splits=-1, reverse=False): + self._set_options(pattern, max_splits, reverse) diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 6a0ea4e80ef..48bf9b25022 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -31,6 +31,8 @@ CountOptions, FilterOptions, MatchSubstringOptions, + SplitOptions, + SplitPatternOptions, MinMaxOptions, PartitionNthOptions, SetLookupOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 30d37f54e0c..0b8181baa7a 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1684,6 +1684,18 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CMatchSubstringOptions(c_string pattern) c_string pattern + cdef cppclass CSplitOptions \ + "arrow::compute::SplitOptions"(CFunctionOptions): + CSplitOptions(int64_t max_splits, c_bool reverse) + int64_t max_splits + c_bool reverse + + cdef cppclass CSplitPatternOptions \ + "arrow::compute::SplitPatternOptions"(CSplitOptions): + CSplitPatternOptions(c_string pattern, int64_t max_splits, + c_bool reverse) + c_string pattern + cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions): CCastOptions() CCastOptions(c_bool safe) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 3c29d8a5259..048245cf871 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -259,6 +259,51 @@ def test_match_substring(): assert expected.equals(result) +def test_split_pattern(): + arr = pa.array(["-foo---bar--", "---foo---b"]) + result = pc.split_pattern(arr, pattern="---") + expected = pa.array([["-foo", "bar--"], ["", "foo", "b"]]) + assert expected.equals(result) + + result = pc.split_pattern(arr, pattern="---", max_splits=1) + expected = pa.array([["-foo", "bar--"], ["", "foo---b"]]) + assert expected.equals(result) + + result = pc.split_pattern(arr, pattern="---", max_splits=1, reverse=True) + expected = pa.array([["-foo", "bar--"], ["---foo", "b"]]) + assert expected.equals(result) + + +def test_split_whitespace_utf8(): + arr = pa.array(["foo bar", " foo \u3000\tb"]) + result = pc.utf8_split_whitespace(arr) + expected = pa.array([["foo", "bar"], ["", "foo", "b"]]) + assert expected.equals(result) + + result = pc.utf8_split_whitespace(arr, max_splits=1) + expected = pa.array([["foo", "bar"], ["", "foo \u3000\tb"]]) + assert expected.equals(result) + + result = pc.utf8_split_whitespace(arr, max_splits=1, reverse=True) + expected = pa.array([["foo", "bar"], [" foo", "b"]]) + assert expected.equals(result) + + +def test_split_whitespace_ascii(): + arr = pa.array(["foo bar", " foo \u3000\tb"]) + result = pc.ascii_split_whitespace(arr) + expected = pa.array([["foo", "bar"], ["", "foo", "\u3000", "b"]]) + assert expected.equals(result) + + result = pc.ascii_split_whitespace(arr, max_splits=1) + expected = pa.array([["foo", "bar"], ["", "foo \u3000\tb"]]) + assert expected.equals(result) + + result = pc.ascii_split_whitespace(arr, max_splits=1, reverse=True) + expected = pa.array([["foo", "bar"], [" foo \u3000", "b"]]) + assert expected.equals(result) + + def test_min_max(): # An example generated function wrapper with possible options data = [4, 5, 6, None, 1]