Skip to content

Commit ab68e08

Browse files
committed
Begin merging the UTF8String implementations.
1 parent 480a74a commit ab68e08

File tree

5 files changed

+119
-51
lines changed

5 files changed

+119
-51
lines changed

core/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@
9191
<artifactId>spark-network-shuffle_${scala.binary.version}</artifactId>
9292
<version>${project.version}</version>
9393
</dependency>
94+
<dependency>
95+
<groupId>org.apache.spark</groupId>
96+
<artifactId>spark-unsafe_${scala.binary.version}</artifactId>
97+
<version>${project.version}</version>
98+
</dependency>
9499
<dependency>
95100
<groupId>net.java.dev.jets3t</groupId>
96101
<artifactId>jets3t</artifactId>

sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala

Lines changed: 18 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ package org.apache.spark.sql.types
1919

2020
import java.util.Arrays
2121

22+
import org.apache.spark.unsafe.PlatformDependent
23+
import org.apache.spark.unsafe.string.{UTF8StringPointer, UTF8StringMethods}
24+
2225
/**
2326
* A UTF-8 String, as internal representation of StringType in SparkSQL
2427
*
@@ -32,45 +35,31 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
3235

3336
private[this] var bytes: Array[Byte] = _
3437

38+
private val pointer: UTF8StringPointer = new UTF8StringPointer
39+
3540
/**
3641
* Update the UTF8String with String.
3742
*/
3843
def set(str: String): UTF8String = {
39-
bytes = str.getBytes("utf-8")
40-
this
44+
set(str.getBytes("utf-8"))
4145
}
4246

4347
/**
4448
* Update the UTF8String with Array[Byte], which should be encoded in UTF-8
4549
*/
4650
def set(bytes: Array[Byte]): UTF8String = {
4751
this.bytes = bytes
52+
pointer.set(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, bytes.length)
4853
this
4954
}
5055

51-
/**
52-
* Return the number of bytes for a code point with the first byte as `b`
53-
* @param b The first byte of a code point
54-
*/
55-
@inline
56-
private[this] def numOfBytes(b: Byte): Int = {
57-
val offset = (b & 0xFF) - 192
58-
if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1
59-
}
60-
6156
/**
6257
* Return the number of code points in it.
6358
*
6459
* This is only used by Substring() when `start` is negative.
6560
*/
6661
def length(): Int = {
67-
var len = 0
68-
var i: Int = 0
69-
while (i < bytes.length) {
70-
i += numOfBytes(bytes(i))
71-
len += 1
72-
}
73-
len
62+
pointer.getLengthInCodePoints
7463
}
7564

7665
def getBytes: Array[Byte] = {
@@ -90,12 +79,12 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
9079
var c = 0
9180
var i: Int = 0
9281
while (c < start && i < bytes.length) {
93-
i += numOfBytes(bytes(i))
82+
i += UTF8StringMethods.numOfBytes(bytes(i))
9483
c += 1
9584
}
9685
var j = i
9786
while (c < until && j < bytes.length) {
98-
j += numOfBytes(bytes(j))
87+
j += UTF8StringMethods.numOfBytes(bytes(j))
9988
c += 1
10089
}
10190
UTF8String(Arrays.copyOfRange(bytes, i, j))
@@ -150,14 +139,14 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
150139
override def clone(): UTF8String = new UTF8String().set(this.bytes)
151140

152141
override def compare(other: UTF8String): Int = {
153-
var i: Int = 0
154-
val b = other.getBytes
155-
while (i < bytes.length && i < b.length) {
156-
val res = bytes(i).compareTo(b(i))
157-
if (res != 0) return res
158-
i += 1
159-
}
160-
bytes.length - b.length
142+
UTF8StringMethods.compare(
143+
pointer.getBaseObject,
144+
pointer.getBaseOffset,
145+
pointer.getLengthInBytes,
146+
other.pointer.getBaseObject,
147+
other.pointer.getBaseOffset,
148+
other.pointer.getLengthInBytes
149+
)
161150
}
162151

163152
override def compareTo(other: UTF8String): Int = {
@@ -181,14 +170,6 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
181170
}
182171

183172
object UTF8String {
184-
// number of tailing bytes in a UTF8 sequence for a code point
185-
// see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1
186-
private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
187-
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
188-
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
189-
4, 4, 4, 4, 4, 4, 4, 4,
190-
5, 5, 5, 5,
191-
6, 6, 6, 6)
192173

193174
/**
194175
* Create a UTF-8 String from String

unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringMethods.java

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,44 @@ static long getLengthInBytes(Object baseObject, long baseOffset) {
4040
return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
4141
}
4242

43-
public static String toJavaString(Object baseObject, long baseOffset) {
44-
final long lengthInBytes = getLengthInBytes(baseObject, baseOffset);
43+
public static int compare(
44+
Object leftBaseObject,
45+
long leftBaseOffset,
46+
int leftBaseLengthInBytes,
47+
Object rightBaseObject,
48+
long rightBaseOffset,
49+
int rightBaseLengthInBytes) {
50+
int i = 0;
51+
while (i < leftBaseLengthInBytes && i < rightBaseLengthInBytes) {
52+
final byte leftByte = PlatformDependent.UNSAFE.getByte(leftBaseObject, leftBaseOffset + i);
53+
final byte rightByte = PlatformDependent.UNSAFE.getByte(rightBaseObject, rightBaseOffset + i);
54+
final int res = leftByte - rightByte;
55+
if (res != 0) return res;
56+
i += 1;
57+
}
58+
return leftBaseLengthInBytes - rightBaseLengthInBytes;
59+
}
60+
61+
/**
62+
* Return the number of code points in a string.
63+
*
64+
* This is only used by Substring() when `start` is negative.
65+
*/
66+
public static int getLengthInCodePoints(Object baseObject, long baseOffset, int lengthInBytes) {
67+
int len = 0;
68+
int i = 0;
69+
while (i < lengthInBytes) {
70+
i += numOfBytes(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + i));
71+
len += 1;
72+
}
73+
return len;
74+
}
75+
76+
public static String toJavaString(Object baseObject, long baseOffset, int lengthInBytes) {
4577
final byte[] bytes = new byte[(int) lengthInBytes];
4678
PlatformDependent.UNSAFE.copyMemory(
4779
baseObject,
48-
baseOffset + 8, // skip over the length
80+
baseOffset,
4981
bytes,
5082
PlatformDependent.BYTE_ARRAY_OFFSET,
5183
lengthInBytes
@@ -67,15 +99,40 @@ public static String toJavaString(Object baseObject, long baseOffset) {
6799
public static long createFromJavaString(Object baseObject, long baseOffset, String str) {
68100
final byte[] strBytes = str.getBytes();
69101
final long strLengthInBytes = strBytes.length;
70-
PlatformDependent.UNSAFE.putLong(baseObject, baseOffset, strLengthInBytes);
71102
PlatformDependent.copyMemory(
72103
strBytes,
73104
PlatformDependent.BYTE_ARRAY_OFFSET,
74105
baseObject,
75-
baseOffset + 8,
106+
baseOffset,
76107
strLengthInBytes
77108
);
78-
return (8 + strLengthInBytes);
109+
return strLengthInBytes;
79110
}
80111

112+
/**
113+
* Return the number of bytes for a code point with the first byte as `b`
114+
* @param b The first byte of a code point
115+
*/
116+
public static int numOfBytes(byte b) {
117+
final int offset = (b & 0xFF) - 192;
118+
if (offset >= 0) {
119+
return bytesOfCodePointInUTF8[offset];
120+
} else {
121+
return 1;
122+
}
123+
}
124+
125+
/**
126+
* number of tailing bytes in a UTF8 sequence for a code point
127+
* see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1
128+
*/
129+
private static int[] bytesOfCodePointInUTF8 = new int[] {
130+
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
131+
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
132+
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
133+
4, 4, 4, 4, 4, 4, 4, 4,
134+
5, 5, 5, 5,
135+
6, 6, 6, 6
136+
};
137+
81138
}

unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringPointer.java

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,39 @@
1717

1818
package org.apache.spark.unsafe.string;
1919

20-
import org.apache.spark.unsafe.memory.MemoryLocation;
20+
import javax.annotation.Nullable;
2121

2222
/**
2323
* A pointer to UTF8String data.
2424
*/
25-
public class UTF8StringPointer extends MemoryLocation {
25+
public class UTF8StringPointer {
2626

27-
public long getLengthInBytes() { return UTF8StringMethods.getLengthInBytes(obj, offset); }
27+
@Nullable
28+
protected Object obj;
29+
protected long offset;
30+
protected int lengthInBytes;
2831

29-
public String toJavaString() { return UTF8StringMethods.toJavaString(obj, offset); }
32+
public UTF8StringPointer() { }
33+
34+
public void set(Object obj, long offset, int lengthInBytes) {
35+
this.obj = obj;
36+
this.offset = offset;
37+
this.lengthInBytes = lengthInBytes;
38+
}
39+
40+
public int getLengthInCodePoints() {
41+
return UTF8StringMethods.getLengthInCodePoints(obj, offset, lengthInBytes);
42+
}
43+
44+
public int getLengthInBytes() { return lengthInBytes; }
45+
46+
public Object getBaseObject() { return obj; }
47+
48+
public long getBaseOffset() { return offset; }
49+
50+
public String toJavaString() {
51+
return UTF8StringMethods.toJavaString(obj, offset, lengthInBytes);
52+
}
3053

3154
@Override public String toString() { return toJavaString(); }
3255
}

unsafe/src/test/java/org/apache/spark/unsafe/string/TestUTF8String.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@ public void toStringTest() {
3232
final byte[] javaStrBytes = javaStr.getBytes();
3333
final int paddedSizeInWords = javaStrBytes.length / 8 + (javaStrBytes.length % 8 == 0 ? 0 : 1);
3434
final MemoryLocation memory = MemoryBlock.fromLongArray(new long[paddedSizeInWords]);
35-
final long bytesWritten =
36-
UTF8StringMethods.createFromJavaString(memory.getBaseObject(), memory.getBaseOffset(), javaStr);
37-
Assert.assertEquals(8 + javaStrBytes.length, bytesWritten);
35+
final long bytesWritten = UTF8StringMethods.createFromJavaString(
36+
memory.getBaseObject(),
37+
memory.getBaseOffset(),
38+
javaStr);
39+
Assert.assertEquals(javaStrBytes.length, bytesWritten);
3840
final UTF8StringPointer utf8String = new UTF8StringPointer();
39-
utf8String.setObjAndOffset(memory.getBaseObject(), memory.getBaseOffset());
41+
utf8String.set(memory.getBaseObject(), memory.getBaseOffset(), bytesWritten);
4042
Assert.assertEquals(javaStr, utf8String.toJavaString());
4143
}
4244
}

0 commit comments

Comments
 (0)