diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java similarity index 80% rename from sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java index f0f66bae245fd..481ea89090b2a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java @@ -15,9 +15,8 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.expressions.codegen; +package org.apache.spark.unsafe; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -34,7 +33,18 @@ public class UTF8StringBuilder { public UTF8StringBuilder() { // Since initial buffer size is 16 in `StringBuilder`, we set the same size here - this.buffer = new byte[16]; + this(16); + } + + public UTF8StringBuilder(int initialSize) { + if (initialSize < 0) { + throw new IllegalArgumentException("Size must be non-negative"); + } + if (initialSize > ARRAY_MAX) { + throw new IllegalArgumentException( + "Size " + initialSize + " exceeded maximum size of " + ARRAY_MAX); + } + this.buffer = new byte[initialSize]; } // Grows the buffer by at least `neededSize` @@ -72,6 +82,17 @@ public void append(String value) { append(UTF8String.fromString(value)); } + public void appendBytes(Object base, long offset, int length) { + grow(length); + Platform.copyMemory( + base, + offset, + buffer, + cursor, + length); + cursor += length; + } + public UTF8String build() { return UTF8String.fromBytes(buffer, 0, totalSize()); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 3a3bfc4a94bb3..00c98c91a6d7f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -32,6 +32,7 @@ import com.google.common.primitives.Ints; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UTF8StringBuilder; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -973,12 +974,29 @@ public UTF8String[] split(UTF8String pattern, int limit) { } public UTF8String replace(UTF8String search, UTF8String replace) { - if (EMPTY_UTF8.equals(search)) { + // This implementation is loosely based on commons-lang3's StringUtils.replace(). + if (numBytes == 0 || search.numBytes == 0) { return this; } - String replaced = toString().replace( - search.toString(), replace.toString()); - return fromString(replaced); + // Find the first occurrence of the search string. + int start = 0; + int end = this.find(search, start); + if (end == -1) { + // Search string was not found, so string is unchanged. + return this; + } + // At least one match was found. Estimate space needed for result. + // The 16x multiplier here is chosen to match commons-lang3's implementation. + int increase = Math.max(0, replace.numBytes - search.numBytes) * 16; + final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase); + while (end != -1) { + buf.appendBytes(this.base, this.offset + start, end - start); + buf.append(replace); + start = end + search.numBytes; + end = this.find(search, start); + } + buf.appendBytes(this.base, this.offset + start, numBytes - start); + return buf.build(); } // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index cf9cc6b1800a9..bc75fa9e724a0 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -403,6 +403,44 @@ public void split() { new UTF8String[]{fromString("ab"), fromString("def,ghi,")})); } + @Test + public void replace() { + assertEquals( + fromString("re123ace"), + fromString("replace").replace(fromString("pl"), fromString("123"))); + assertEquals( + fromString("reace"), + fromString("replace").replace(fromString("pl"), fromString(""))); + assertEquals( + fromString("replace"), + fromString("replace").replace(fromString(""), fromString("123"))); + // tests for multiple replacements + assertEquals( + fromString("a12ca12c"), + fromString("abcabc").replace(fromString("b"), fromString("12"))); + assertEquals( + fromString("adad"), + fromString("abcdabcd").replace(fromString("bc"), fromString(""))); + // tests for single character search and replacement strings + assertEquals( + fromString("AbcAbc"), + fromString("abcabc").replace(fromString("a"), fromString("A"))); + assertEquals( + fromString("abcabc"), + fromString("abcabc").replace(fromString("Z"), fromString("A"))); + // Tests with non-ASCII characters + assertEquals( + fromString("花ab界"), + fromString("花花世界").replace(fromString("花世"), fromString("ab"))); + assertEquals( + fromString("a水c"), + fromString("a火c").replace(fromString("火"), fromString("水"))); + // Tests for a large number of replacements, triggering UTF8StringBuilder resize + assertEquals( + fromString("abcd").repeat(17), + fromString("a").repeat(17).replace(fromString("a"), fromString("abcd"))); + } + @Test public void levenshteinDistance() { assertEquals(0, EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8)); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f8c1102953ab3..969128838eba4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 41d9b06ed1d01..8477e63135e30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.{ByteArray, UTF8String}