Skip to content

Commit d82c695

Browse files
uros-dbcloud-fan
authored andcommitted
[SPARK-49207][SQL] Fix one-to-many case mapping in SplitPart and StringSplitSQL
### What changes were proposed in this pull request? Fix the following string expressions to handle one-to-many case mapping properly: - SplitPart - StringSplitSQL Examples of incorrect results (under `UTF8_LCASE` collation): ``` SplitPart("Ai\u0307B", "İ", 2) // returns: "\u0307B" (incorrect), instead of: "B" (correct) SplitPart("AİB", "i\u0307", 1) // returns: "AİB", instead of: "A", "B" (correct) StringSplitSQL("Ai\u0307B", "İ") // returns: ["A", "\u0307B"] (incorrect), instead of: ["A", "B"] (correct) StringSplitSQL("AİB", "i\u0307") // returns: ["AİB"] (incorrect), instead of: ["A", "B"] (correct) ``` ### Why are the changes needed? Currently, some string expressions are giving wrong results when working with one-to-many case mapping. ### Does this PR introduce _any_ user-facing change? Yes, this expression will now work properly with surrogate pairs: `split_part`. ### How was this patch tested? New tests in `CollationSupportSuite`. ### Was this patch authored or co-authored using generative AI tooling? Yes. Closes #47715 from uros-db/fix-splitpart. Authored-by: Uros Bojanic <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent e77c8fb commit d82c695

File tree

2 files changed

+161
-33
lines changed

2 files changed

+161
-33
lines changed

common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import java.util.Iterator;
3737
import java.util.List;
3838
import java.util.Map;
39-
import java.util.regex.Pattern;
4039

4140
/**
4241
* Utility class for collation-aware UTF8String operations.
@@ -1208,24 +1207,43 @@ public static UTF8String[] splitSQL(final UTF8String input, final UTF8String del
12081207

12091208
public static UTF8String[] lowercaseSplitSQL(final UTF8String string, final UTF8String delimiter,
12101209
final int limit) {
1211-
if (delimiter.numBytes() == 0) return new UTF8String[] { string };
1212-
if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 };
1213-
Pattern pattern = Pattern.compile(Pattern.quote(delimiter.toString()),
1214-
CollationSupport.lowercaseRegexFlags);
1215-
String[] splits = pattern.split(string.toString(), limit);
1216-
UTF8String[] res = new UTF8String[splits.length];
1217-
for (int i = 0; i < res.length; i++) {
1218-
res[i] = UTF8String.fromString(splits[i]);
1210+
if (delimiter.numBytes() == 0) return new UTF8String[] { string };
1211+
if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 };
1212+
1213+
List<UTF8String> strings = new ArrayList<>();
1214+
UTF8String lowercaseDelimiter = lowerCaseCodePoints(delimiter);
1215+
int startIndex = 0, nextMatch = 0, nextMatchLength;
1216+
while (nextMatch != MATCH_NOT_FOUND) {
1217+
if (limit > 0 && strings.size() == limit - 1) {
1218+
break;
1219+
}
1220+
nextMatch = lowercaseFind(string, lowercaseDelimiter, startIndex);
1221+
if (nextMatch != MATCH_NOT_FOUND) {
1222+
nextMatchLength = lowercaseMatchLengthFrom(string, lowercaseDelimiter, nextMatch);
1223+
strings.add(string.substring(startIndex, nextMatch));
1224+
startIndex = nextMatch + nextMatchLength;
12191225
}
1220-
return res;
1226+
}
1227+
if (startIndex <= string.numChars()) {
1228+
strings.add(string.substring(startIndex, string.numChars()));
1229+
}
1230+
if (limit == 0) {
1231+
// Remove trailing empty strings
1232+
int i = strings.size() - 1;
1233+
while (i >= 0 && strings.get(i).numBytes() == 0) {
1234+
strings.remove(i);
1235+
i--;
1236+
}
1237+
}
1238+
return strings.toArray(new UTF8String[0]);
12211239
}
12221240

12231241
public static UTF8String[] icuSplitSQL(final UTF8String string, final UTF8String delimiter,
12241242
final int limit, final int collationId) {
12251243
if (delimiter.numBytes() == 0) return new UTF8String[] { string };
12261244
if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 };
12271245
List<UTF8String> strings = new ArrayList<>();
1228-
String target = string.toString(), pattern = delimiter.toString();
1246+
String target = string.toValidString(), pattern = delimiter.toValidString();
12291247
StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId);
12301248
int start = 0, end;
12311249
while ((end = stringSearch.next()) != StringSearch.DONE) {

common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java

Lines changed: 132 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -567,12 +567,17 @@ public void testEndsWith() throws SparkException {
567567
assertEndsWith("the İo", "İo", "UTF8_LCASE", true);
568568
}
569569

570+
/**
571+
* Verify the behaviour of the `StringSplitSQL` collation support class.
572+
*/
573+
570574
private void assertStringSplitSQL(String str, String delimiter, String collationName,
571575
UTF8String[] expected) throws SparkException {
572576
UTF8String s = UTF8String.fromString(str);
573577
UTF8String d = UTF8String.fromString(delimiter);
574578
int collationId = CollationFactory.collationNameToId(collationName);
575-
assertArrayEquals(expected, CollationSupport.StringSplitSQL.exec(s, d, collationId));
579+
UTF8String[] result = CollationSupport.StringSplitSQL.exec(s, d, collationId);
580+
assertArrayEquals(expected, result);
576581
}
577582

578583
@Test
@@ -590,7 +595,21 @@ public void testStringSplitSQL() throws SparkException {
590595
var array_A_B = new UTF8String[] { UTF8String.fromString("A"), UTF8String.fromString("B") };
591596
var array_a_e = new UTF8String[] { UTF8String.fromString("ä"), UTF8String.fromString("e") };
592597
var array_Aa_bB = new UTF8String[] { UTF8String.fromString("Aa"), UTF8String.fromString("bB") };
593-
// Edge cases
598+
var array_Turkish_uppercase_dotted_I = new UTF8String[] { UTF8String.fromString("İ") };
599+
var array_Turkish_lowercase_dotted_i = new UTF8String[] { UTF8String.fromString("i\u0307") };
600+
var array_i = new UTF8String[] { UTF8String.fromString("i"), UTF8String.fromString("") };
601+
var array_dot = new UTF8String[] { UTF8String.fromString(""), UTF8String.fromString("\u0307") };
602+
var array_AiB = new UTF8String[] { UTF8String.fromString("Ai\u0307B") };
603+
var array_AIB = new UTF8String[] { UTF8String.fromString("AİB") };
604+
var array_small_nonfinal_sigma = new UTF8String[] { UTF8String.fromString("σ") };
605+
var array_small_final_sigma = new UTF8String[] { UTF8String.fromString("ς") };
606+
var array_capital_sigma = new UTF8String[] { UTF8String.fromString("Σ") };
607+
var array_a_b_c = new UTF8String[] { UTF8String.fromString("a"), UTF8String.fromString("b"),
608+
UTF8String.fromString("c") };
609+
var array_emojis = new UTF8String[] { UTF8String.fromString("😀"), UTF8String.fromString("😄") };
610+
var array_AOB = new UTF8String[] { UTF8String.fromString("A𐐅B") };
611+
var array_AoB = new UTF8String[] { UTF8String.fromString("A𐐭B") };
612+
// Empty strings.
594613
assertStringSplitSQL("", "", "UTF8_BINARY", empty_match);
595614
assertStringSplitSQL("abc", "", "UTF8_BINARY", array_abc);
596615
assertStringSplitSQL("", "abc", "UTF8_BINARY", empty_match);
@@ -603,7 +622,7 @@ public void testStringSplitSQL() throws SparkException {
603622
assertStringSplitSQL("", "", "UNICODE_CI", empty_match);
604623
assertStringSplitSQL("abc", "", "UNICODE_CI", array_abc);
605624
assertStringSplitSQL("", "abc", "UNICODE_CI", empty_match);
606-
// Basic tests
625+
// Basic tests.
607626
assertStringSplitSQL("1a2", "a", "UTF8_BINARY", array_1_2);
608627
assertStringSplitSQL("1a2", "A", "UTF8_BINARY", array_1a2);
609628
assertStringSplitSQL("1a2", "b", "UTF8_BINARY", array_1a2);
@@ -617,25 +636,7 @@ public void testStringSplitSQL() throws SparkException {
617636
assertStringSplitSQL("1a2", "A", "UNICODE_CI", array_1_2);
618637
assertStringSplitSQL("1a2", "1A2", "UNICODE_CI", full_match);
619638
assertStringSplitSQL("1a2", "123", "UNICODE_CI", array_1a2);
620-
// Case variation
621-
assertStringSplitSQL("AaXbB", "x", "UTF8_BINARY", array_AaXbB);
622-
assertStringSplitSQL("AaXbB", "X", "UTF8_BINARY", array_Aa_bB);
623-
assertStringSplitSQL("AaXbB", "axb", "UNICODE", array_AaXbB);
624-
assertStringSplitSQL("AaXbB", "aXb", "UNICODE", array_A_B);
625-
assertStringSplitSQL("AaXbB", "axb", "UTF8_LCASE", array_A_B);
626-
assertStringSplitSQL("AaXbB", "AXB", "UTF8_LCASE", array_A_B);
627-
assertStringSplitSQL("AaXbB", "axb", "UNICODE_CI", array_A_B);
628-
assertStringSplitSQL("AaXbB", "AxB", "UNICODE_CI", array_A_B);
629-
// Accent variation
630-
assertStringSplitSQL("aBcDe", "bćd", "UTF8_BINARY", array_aBcDe);
631-
assertStringSplitSQL("aBcDe", "BćD", "UTF8_BINARY", array_aBcDe);
632-
assertStringSplitSQL("aBcDe", "abćde", "UNICODE", array_aBcDe);
633-
assertStringSplitSQL("aBcDe", "aBćDe", "UNICODE", array_aBcDe);
634-
assertStringSplitSQL("aBcDe", "bćd", "UTF8_LCASE", array_aBcDe);
635-
assertStringSplitSQL("aBcDe", "BĆD", "UTF8_LCASE", array_aBcDe);
636-
assertStringSplitSQL("aBcDe", "abćde", "UNICODE_CI", array_aBcDe);
637-
assertStringSplitSQL("aBcDe", "AbĆdE", "UNICODE_CI", array_aBcDe);
638-
// Variable byte length characters
639+
// Advanced tests.
639640
assertStringSplitSQL("äb世De", "b世D", "UTF8_BINARY", array_a_e);
640641
assertStringSplitSQL("äb世De", "B世d", "UTF8_BINARY", array_special);
641642
assertStringSplitSQL("äbćδe", "bćδ", "UTF8_BINARY", array_a_e);
@@ -652,6 +653,115 @@ public void testStringSplitSQL() throws SparkException {
652653
assertStringSplitSQL("äb世De", "AB世dE", "UNICODE_CI", array_special);
653654
assertStringSplitSQL("äbćδe", "ÄbćδE", "UNICODE_CI", full_match);
654655
assertStringSplitSQL("äbćδe", "ÄBcΔÉ", "UNICODE_CI", array_abcde);
656+
// Case variation.
657+
assertStringSplitSQL("AaXbB", "x", "UTF8_BINARY", array_AaXbB);
658+
assertStringSplitSQL("AaXbB", "X", "UTF8_BINARY", array_Aa_bB);
659+
assertStringSplitSQL("AaXbB", "axb", "UNICODE", array_AaXbB);
660+
assertStringSplitSQL("AaXbB", "aXb", "UNICODE", array_A_B);
661+
assertStringSplitSQL("AaXbB", "axb", "UTF8_LCASE", array_A_B);
662+
assertStringSplitSQL("AaXbB", "AXB", "UTF8_LCASE", array_A_B);
663+
assertStringSplitSQL("AaXbB", "axb", "UNICODE_CI", array_A_B);
664+
assertStringSplitSQL("AaXbB", "AxB", "UNICODE_CI", array_A_B);
665+
// Accent variation.
666+
assertStringSplitSQL("aBcDe", "bćd", "UTF8_BINARY", array_aBcDe);
667+
assertStringSplitSQL("aBcDe", "BćD", "UTF8_BINARY", array_aBcDe);
668+
assertStringSplitSQL("aBcDe", "abćde", "UNICODE", array_aBcDe);
669+
assertStringSplitSQL("aBcDe", "aBćDe", "UNICODE", array_aBcDe);
670+
assertStringSplitSQL("aBcDe", "bćd", "UTF8_LCASE", array_aBcDe);
671+
assertStringSplitSQL("aBcDe", "BĆD", "UTF8_LCASE", array_aBcDe);
672+
assertStringSplitSQL("aBcDe", "abćde", "UNICODE_CI", array_aBcDe);
673+
assertStringSplitSQL("aBcDe", "AbĆdE", "UNICODE_CI", array_aBcDe);
674+
// One-to-many case mapping (e.g. Turkish dotted I).
675+
assertStringSplitSQL("İ", "i", "UTF8_BINARY", array_Turkish_uppercase_dotted_I);
676+
assertStringSplitSQL("İ", "i", "UTF8_LCASE", array_Turkish_uppercase_dotted_I);
677+
assertStringSplitSQL("İ", "i", "UNICODE", array_Turkish_uppercase_dotted_I);
678+
assertStringSplitSQL("İ", "i", "UNICODE_CI", array_Turkish_uppercase_dotted_I);
679+
assertStringSplitSQL("İ", "\u0307", "UTF8_BINARY", array_Turkish_uppercase_dotted_I);
680+
assertStringSplitSQL("İ", "\u0307", "UTF8_LCASE", array_Turkish_uppercase_dotted_I);
681+
assertStringSplitSQL("İ", "\u0307", "UNICODE", array_Turkish_uppercase_dotted_I);
682+
assertStringSplitSQL("İ", "\u0307", "UNICODE_CI", array_Turkish_uppercase_dotted_I);
683+
assertStringSplitSQL("i\u0307", "i", "UTF8_BINARY", array_dot);
684+
assertStringSplitSQL("i\u0307", "i", "UTF8_LCASE", array_dot);
685+
assertStringSplitSQL("i\u0307", "i", "UNICODE", array_Turkish_lowercase_dotted_i);
686+
assertStringSplitSQL("i\u0307", "i", "UNICODE_CI", array_Turkish_lowercase_dotted_i);
687+
assertStringSplitSQL("i\u0307", "\u0307", "UTF8_BINARY", array_i);
688+
assertStringSplitSQL("i\u0307", "\u0307", "UTF8_LCASE", array_i);
689+
assertStringSplitSQL("i\u0307", "\u0307", "UNICODE", array_Turkish_lowercase_dotted_i);
690+
assertStringSplitSQL("i\u0307", "\u0307", "UNICODE_CI", array_Turkish_lowercase_dotted_i);
691+
assertStringSplitSQL("AİB", "İ", "UTF8_BINARY", array_A_B);
692+
assertStringSplitSQL("AİB", "İ", "UTF8_LCASE", array_A_B);
693+
assertStringSplitSQL("AİB", "İ", "UNICODE", array_A_B);
694+
assertStringSplitSQL("AİB", "İ", "UNICODE_CI", array_A_B);
695+
assertStringSplitSQL("AİB", "i\u0307", "UTF8_BINARY", array_AIB);
696+
assertStringSplitSQL("AİB", "i\u0307", "UTF8_LCASE", array_A_B);
697+
assertStringSplitSQL("AİB", "i\u0307", "UNICODE", array_AIB);
698+
assertStringSplitSQL("AİB", "i\u0307", "UNICODE_CI", array_A_B);
699+
assertStringSplitSQL("Ai\u0307B", "İ", "UTF8_BINARY", array_AiB);
700+
assertStringSplitSQL("Ai\u0307B", "İ", "UTF8_LCASE", array_A_B);
701+
assertStringSplitSQL("Ai\u0307B", "İ", "UNICODE", array_AiB);
702+
assertStringSplitSQL("Ai\u0307B", "İ", "UNICODE_CI", array_A_B);
703+
assertStringSplitSQL("Ai\u0307B", "i\u0307", "UTF8_BINARY", array_A_B);
704+
assertStringSplitSQL("Ai\u0307B", "i\u0307", "UTF8_LCASE", array_A_B);
705+
assertStringSplitSQL("Ai\u0307B", "i\u0307", "UNICODE", array_A_B);
706+
assertStringSplitSQL("Ai\u0307B", "i\u0307", "UNICODE_CI", array_A_B);
707+
// Conditional case mapping (e.g. Greek sigmas).
708+
assertStringSplitSQL("σ", "σ", "UTF8_BINARY", full_match);
709+
assertStringSplitSQL("σ", "σ", "UTF8_LCASE", full_match);
710+
assertStringSplitSQL("σ", "σ", "UNICODE", full_match);
711+
assertStringSplitSQL("σ", "σ", "UNICODE_CI", full_match);
712+
assertStringSplitSQL("σ", "ς", "UTF8_BINARY", array_small_nonfinal_sigma);
713+
assertStringSplitSQL("σ", "ς", "UTF8_LCASE", full_match);
714+
assertStringSplitSQL("σ", "ς", "UNICODE", array_small_nonfinal_sigma);
715+
assertStringSplitSQL("σ", "ς", "UNICODE_CI", full_match);
716+
assertStringSplitSQL("σ", "Σ", "UTF8_BINARY", array_small_nonfinal_sigma);
717+
assertStringSplitSQL("σ", "Σ", "UTF8_LCASE", full_match);
718+
assertStringSplitSQL("σ", "Σ", "UNICODE", array_small_nonfinal_sigma);
719+
assertStringSplitSQL("σ", "Σ", "UNICODE_CI", full_match);
720+
assertStringSplitSQL("ς", "σ", "UTF8_BINARY", array_small_final_sigma);
721+
assertStringSplitSQL("ς", "σ", "UTF8_LCASE", full_match);
722+
assertStringSplitSQL("ς", "σ", "UNICODE", array_small_final_sigma);
723+
assertStringSplitSQL("ς", "σ", "UNICODE_CI", full_match);
724+
assertStringSplitSQL("ς", "ς", "UTF8_BINARY", full_match);
725+
assertStringSplitSQL("ς", "ς", "UTF8_LCASE", full_match);
726+
assertStringSplitSQL("ς", "ς", "UNICODE", full_match);
727+
assertStringSplitSQL("ς", "ς", "UNICODE_CI", full_match);
728+
assertStringSplitSQL("ς", "Σ", "UTF8_BINARY", array_small_final_sigma);
729+
assertStringSplitSQL("ς", "Σ", "UTF8_LCASE", full_match);
730+
assertStringSplitSQL("ς", "Σ", "UNICODE", array_small_final_sigma);
731+
assertStringSplitSQL("ς", "Σ", "UNICODE_CI", full_match);
732+
assertStringSplitSQL("Σ", "σ", "UTF8_BINARY", array_capital_sigma);
733+
assertStringSplitSQL("Σ", "σ", "UTF8_LCASE", full_match);
734+
assertStringSplitSQL("Σ", "σ", "UNICODE", array_capital_sigma);
735+
assertStringSplitSQL("Σ", "σ", "UNICODE_CI", full_match);
736+
assertStringSplitSQL("Σ", "ς", "UTF8_BINARY", array_capital_sigma);
737+
assertStringSplitSQL("Σ", "ς", "UTF8_LCASE", full_match);
738+
assertStringSplitSQL("Σ", "ς", "UNICODE", array_capital_sigma);
739+
assertStringSplitSQL("Σ", "ς", "UNICODE_CI", full_match);
740+
assertStringSplitSQL("Σ", "Σ", "UTF8_BINARY", full_match);
741+
assertStringSplitSQL("Σ", "Σ", "UTF8_LCASE", full_match);
742+
assertStringSplitSQL("Σ", "Σ", "UNICODE", full_match);
743+
assertStringSplitSQL("Σ", "Σ", "UNICODE_CI", full_match);
744+
// Surrogate pairs.
745+
assertStringSplitSQL("a🙃b🙃c", "🙃", "UTF8_BINARY", array_a_b_c);
746+
assertStringSplitSQL("a🙃b🙃c", "🙃", "UTF8_LCASE", array_a_b_c);
747+
assertStringSplitSQL("a🙃b🙃c", "🙃", "UNICODE", array_a_b_c);
748+
assertStringSplitSQL("a🙃b🙃c", "🙃", "UNICODE_CI", array_a_b_c);
749+
assertStringSplitSQL("😀😆😃😄", "😆😃", "UTF8_BINARY", array_emojis);
750+
assertStringSplitSQL("😀😆😃😄", "😆😃", "UTF8_LCASE", array_emojis);
751+
assertStringSplitSQL("😀😆😃😄", "😆😃", "UNICODE", array_emojis);
752+
assertStringSplitSQL("😀😆😃😄", "😆😃", "UNICODE_CI", array_emojis);
753+
assertStringSplitSQL("A𐐅B", "𐐅", "UTF8_BINARY", array_A_B);
754+
assertStringSplitSQL("A𐐅B", "𐐅", "UTF8_LCASE", array_A_B);
755+
assertStringSplitSQL("A𐐅B", "𐐅", "UNICODE", array_A_B);
756+
assertStringSplitSQL("A𐐅B", "𐐅", "UNICODE_CI", array_A_B);
757+
assertStringSplitSQL("A𐐅B", "𐐭", "UTF8_BINARY", array_AOB);
758+
assertStringSplitSQL("A𐐅B", "𐐭", "UTF8_LCASE", array_A_B);
759+
assertStringSplitSQL("A𐐅B", "𐐭", "UNICODE", array_AOB);
760+
assertStringSplitSQL("A𐐅B", "𐐭", "UNICODE_CI", array_A_B);
761+
assertStringSplitSQL("A𐐭B", "𐐅", "UTF8_BINARY", array_AoB);
762+
assertStringSplitSQL("A𐐭B", "𐐅", "UTF8_LCASE", array_A_B);
763+
assertStringSplitSQL("A𐐭B", "𐐅", "UNICODE", array_AoB);
764+
assertStringSplitSQL("A𐐭B", "𐐅", "UNICODE_CI", array_A_B);
655765
}
656766

657767
private void assertUpper(String target, String collationName, String expected)

0 commit comments

Comments
 (0)