Skip to content

Commit f03e9c1

Browse files
committed
Play around with Unsafe implementations of more string methods.
1 parent ab68e08 commit f03e9c1

File tree

3 files changed

+104
-27
lines changed

3 files changed

+104
-27
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ package org.apache.spark.sql.types
1919

2020
import java.util.Arrays
2121

22-
import org.apache.spark.unsafe.PlatformDependent
23-
import org.apache.spark.unsafe.string.{UTF8StringPointer, UTF8StringMethods}
22+
import org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET
23+
import org.apache.spark.unsafe.array.ByteArrayMethods
24+
import org.apache.spark.unsafe.string.UTF8StringMethods
2425

2526
/**
2627
* A UTF-8 String, as internal representation of StringType in SparkSQL
@@ -35,8 +36,6 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
3536

3637
private[this] var bytes: Array[Byte] = _
3738

38-
private val pointer: UTF8StringPointer = new UTF8StringPointer
39-
4039
/**
4140
* Update the UTF8String with String.
4241
*/
@@ -49,7 +48,6 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
4948
*/
5049
def set(bytes: Array[Byte]): UTF8String = {
5150
this.bytes = bytes
52-
pointer.set(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, bytes.length)
5351
this
5452
}
5553

@@ -59,7 +57,7 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
5957
* This is only used by Substring() when `start` is negative.
6058
*/
6159
def length(): Int = {
62-
pointer.getLengthInCodePoints
60+
UTF8StringMethods.getLengthInCodePoints(bytes, BYTE_ARRAY_OFFSET, bytes.length)
6361
}
6462

6563
def getBytes: Array[Byte] = {
@@ -107,19 +105,27 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
107105
}
108106

109107
def startsWith(prefix: UTF8String): Boolean = {
110-
val b = prefix.getBytes
111-
if (b.length > bytes.length) {
112-
return false
113-
}
114-
Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b)
108+
val prefixBytes = prefix.getBytes
109+
UTF8StringMethods.startsWith(
110+
bytes,
111+
BYTE_ARRAY_OFFSET,
112+
bytes.length,
113+
prefixBytes,
114+
BYTE_ARRAY_OFFSET,
115+
prefixBytes.length
116+
)
115117
}
116118

117119
def endsWith(suffix: UTF8String): Boolean = {
118-
val b = suffix.getBytes
119-
if (b.length > bytes.length) {
120-
return false
121-
}
122-
Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b)
120+
val suffixBytes = suffix.getBytes
121+
UTF8StringMethods.endsWith(
122+
bytes,
123+
BYTE_ARRAY_OFFSET,
124+
bytes.length,
125+
suffixBytes,
126+
BYTE_ARRAY_OFFSET,
127+
suffixBytes.length
128+
)
123129
}
124130

125131
def toUpperCase(): UTF8String = {
@@ -139,13 +145,14 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
139145
override def clone(): UTF8String = new UTF8String().set(this.bytes)
140146

141147
override def compare(other: UTF8String): Int = {
148+
val otherBytes = other.getBytes
142149
UTF8StringMethods.compare(
143-
pointer.getBaseObject,
144-
pointer.getBaseOffset,
145-
pointer.getLengthInBytes,
146-
other.pointer.getBaseObject,
147-
other.pointer.getBaseOffset,
148-
other.pointer.getLengthInBytes
150+
bytes,
151+
BYTE_ARRAY_OFFSET,
152+
bytes.length,
153+
otherBytes,
154+
BYTE_ARRAY_OFFSET,
155+
otherBytes.length
149156
)
150157
}
151158

@@ -155,7 +162,14 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
155162

156163
override def equals(other: Any): Boolean = other match {
157164
case s: UTF8String =>
158-
Arrays.equals(bytes, s.getBytes)
165+
val otherBytes = s.getBytes
166+
otherBytes.length == bytes.length && ByteArrayMethods.arrayEquals(
167+
bytes,
168+
BYTE_ARRAY_OFFSET,
169+
otherBytes,
170+
BYTE_ARRAY_OFFSET,
171+
otherBytes.length
172+
)
159173
case s: String =>
160174
// This is only used for Catalyst unit tests
161175
// fail fast

unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,34 @@
2323

2424
public class ByteArrayMethods {
2525

26+
// TODO: there are substantial opportunities for optimization here and we should investigate them.
27+
// See the fast comparisons in Guava's UnsignedBytes, for instance:
28+
// https://code.google.com/p/guava-libraries/source/browse/guava/src/com/google/common/primitives/UnsignedBytes.java
29+
2630
private ByteArrayMethods() {
2731
// Private constructor, since this class only contains static methods.
2832
}
2933

34+
/**
35+
* Optimized equality check for equal-length byte arrays.
36+
* @return true if the arrays are equal, false otherwise
37+
*/
38+
public static boolean arrayEquals(
39+
Object leftBaseObject,
40+
long leftBaseOffset,
41+
Object rightBaseObject,
42+
long rightBaseOffset,
43+
long arrayLengthInBytes) {
44+
for (int i = 0; i < arrayLengthInBytes; i++) {
45+
final byte left =
46+
PlatformDependent.UNSAFE.getByte(leftBaseObject, leftBaseOffset + i);
47+
final byte right =
48+
PlatformDependent.UNSAFE.getByte(rightBaseObject, rightBaseOffset + i);
49+
if (left != right) return false;
50+
}
51+
return true;
52+
}
53+
3054
/**
3155
* Optimized byte array equality check for 8-byte-word-aligned byte arrays.
3256
* @return true if the arrays are equal, false otherwise

unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringMethods.java

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.unsafe.string;
1919

2020
import org.apache.spark.unsafe.PlatformDependent;
21+
import org.apache.spark.unsafe.array.ByteArrayMethods;
2122

2223
import java.io.UnsupportedEncodingException;import java.lang.Object;import java.lang.String;
2324

@@ -43,19 +44,57 @@ static long getLengthInBytes(Object baseObject, long baseOffset) {
4344
public static int compare(
4445
Object leftBaseObject,
4546
long leftBaseOffset,
46-
int leftBaseLengthInBytes,
47+
int leftLengthInBytes,
4748
Object rightBaseObject,
4849
long rightBaseOffset,
49-
int rightBaseLengthInBytes) {
50+
int rightLengthInBytes) {
5051
int i = 0;
51-
while (i < leftBaseLengthInBytes && i < rightBaseLengthInBytes) {
52+
while (i < leftLengthInBytes && i < rightLengthInBytes) {
5253
final byte leftByte = PlatformDependent.UNSAFE.getByte(leftBaseObject, leftBaseOffset + i);
5354
final byte rightByte = PlatformDependent.UNSAFE.getByte(rightBaseObject, rightBaseOffset + i);
5455
final int res = leftByte - rightByte;
5556
if (res != 0) return res;
5657
i += 1;
5758
}
58-
return leftBaseLengthInBytes - rightBaseLengthInBytes;
59+
return leftLengthInBytes - rightLengthInBytes;
60+
}
61+
62+
public static boolean startsWith(
63+
Object strBaseObject,
64+
long strBaseOffset,
65+
int strLengthInBytes,
66+
Object prefixBaseObject,
67+
long prefixBaseOffset,
68+
int prefixLengthInBytes) {
69+
if (prefixLengthInBytes > strLengthInBytes) {
70+
return false;
71+
} {
72+
return ByteArrayMethods.arrayEquals(
73+
strBaseObject,
74+
strBaseOffset,
75+
prefixBaseObject,
76+
prefixBaseOffset,
77+
prefixLengthInBytes);
78+
}
79+
}
80+
81+
public static boolean endsWith(
82+
Object strBaseObject,
83+
long strBaseOffset,
84+
int strLengthInBytes,
85+
Object suffixBaseObject,
86+
long suffixBaseOffset,
87+
int suffixLengthInBytes) {
88+
if (suffixLengthInBytes > strLengthInBytes) {
89+
return false;
90+
} {
91+
return ByteArrayMethods.arrayEquals(
92+
strBaseObject,
93+
strBaseOffset + strLengthInBytes - suffixLengthInBytes,
94+
suffixBaseObject,
95+
suffixBaseOffset,
96+
suffixLengthInBytes);
97+
}
5998
}
6099

61100
/**

0 commit comments

Comments
 (0)