Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.codegen;
package org.apache.spark.unsafe;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.UTF8String;

Expand All @@ -34,7 +33,18 @@ public class UTF8StringBuilder {

public UTF8StringBuilder() {
// Since initial buffer size is 16 in `StringBuilder`, we set the same size here
this.buffer = new byte[16];
this(16);
}

public UTF8StringBuilder(int initialSize) {
if (initialSize < 0) {
throw new IllegalArgumentException("Size must be non-negative");
}
if (initialSize > ARRAY_MAX) {
throw new IllegalArgumentException(
"Size " + initialSize + " exceeded maximum size of " + ARRAY_MAX);
}
this.buffer = new byte[initialSize];
}

// Grows the buffer by at least `neededSize`
Expand Down Expand Up @@ -72,6 +82,17 @@ public void append(String value) {
append(UTF8String.fromString(value));
}

public void appendBytes(Object base, long offset, int length) {
grow(length);
Platform.copyMemory(
base,
offset,
buffer,
cursor,
length);
cursor += length;
}

public UTF8String build() {
return UTF8String.fromBytes(buffer, 0, totalSize());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.common.primitives.Ints;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UTF8StringBuilder;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;

Expand Down Expand Up @@ -973,12 +974,29 @@ public UTF8String[] split(UTF8String pattern, int limit) {
}

public UTF8String replace(UTF8String search, UTF8String replace) {
if (EMPTY_UTF8.equals(search)) {
// This implementation is loosely based on commons-lang3's StringUtils.replace().
if (numBytes == 0 || search.numBytes == 0) {
return this;
}
String replaced = toString().replace(
search.toString(), replace.toString());
return fromString(replaced);
// Find the first occurrence of the search string.
int start = 0;
int end = this.find(search, start);
if (end == -1) {
// Search string was not found, so string is unchanged.
return this;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One consideration here: do we need to make a defensive copy? If so, we can't do this optimization.

Why might we need to copy a UTF8String? The UTF8String instance itself is effectively immutable, but the underlying storage might be a region of potentially-not-exclusively-owned memory (either direct/off-heap memory or a region of a long[] array), so we might need to perform a copy in case we're going to buffer / otherwise hold onto the UTF8String past a point where the underlying underlying storage memory could be mutated.

I think the most common case to worry about would be a UTF8String which is backed by memory that is part of a larger UnsafeRow. If we're doing row-at-a-time processing and aren't holding onto this UTF8String across rows then I think we're ok since changes to rows' memory during single-row processing would impact many parts of Spark and would probably be detected. In the few places where we do hold references across evaluations / rows then we need to copy, but I suspect most places already do this: for example, see the regexp.clone() in the RegExpReplace expression.

My intuition is that we probably don't need to make a defensive copy here because I doubt we have parts of the code which specifically assume that replace() will copy (i.e. which are abusing replace() as a slow clone() mechanism). Put differently, I suspect that any code which would fail due to lack of copying in replace() is also vulnerable to this problem from other sources (including simply reading a string from a row without further modification), so I don't think we need to add extra copying here.

I'd love to get additional sets of eyes on this, though, and I'd ultimately be ok with changing return this to return this.clone() (and updating the other return this uses in UTF8String) if we conclude that this isn't safe (or are uncertain and want to err on the side of caution).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your intuition is right here.

}
// At least one match was found. Estimate space needed for result.
// The 16x multiplier here is chosen to match commons-lang3's implementation.
int increase = Math.max(0, replace.numBytes - search.numBytes) * 16;
final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase);
while (end != -1) {
buf.appendBytes(this.base, this.offset + start, end - start);
buf.append(replace);
start = end + search.numBytes;
end = this.find(search, start);
}
buf.appendBytes(this.base, this.offset + start, numBytes - start);
return buf.build();
}

// TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,44 @@ public void split() {
new UTF8String[]{fromString("ab"), fromString("def,ghi,")}));
}

@Test
public void replace() {
assertEquals(
fromString("re123ace"),
fromString("replace").replace(fromString("pl"), fromString("123")));
assertEquals(
fromString("reace"),
fromString("replace").replace(fromString("pl"), fromString("")));
assertEquals(
fromString("replace"),
fromString("replace").replace(fromString(""), fromString("123")));
// tests for multiple replacements
assertEquals(
fromString("a12ca12c"),
fromString("abcabc").replace(fromString("b"), fromString("12")));
assertEquals(
fromString("adad"),
fromString("abcdabcd").replace(fromString("bc"), fromString("")));
// tests for single character search and replacement strings
assertEquals(
fromString("AbcAbc"),
fromString("abcabc").replace(fromString("a"), fromString("A")));
assertEquals(
fromString("abcabc"),
fromString("abcabc").replace(fromString("Z"), fromString("A")));
// Tests with non-ASCII characters
assertEquals(
fromString("花ab界"),
fromString("花花世界").replace(fromString("花世"), fromString("ab")));
assertEquals(
fromString("a水c"),
fromString("a火c").replace(fromString("火"), fromString("水")));
// Tests for a large number of replacements, triggering UTF8StringBuilder resize
assertEquals(
fromString("abcd").repeat(17),
fromString("a").repeat(17).replace(fromString("a"), fromString("abcd")));
}

@Test
public void levenshteinDistance() {
assertEquals(0, EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
Expand Down