Skip to content

Commit 53ba9b7

Browse files
committed
Start prototyping Java Row -> UnsafeRow converters
1 parent 1ff814d commit 53ba9b7

File tree

4 files changed

+255
-4
lines changed

4 files changed

+255
-4
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.spark.sql.types.StructType;
2424
import org.apache.spark.unsafe.PlatformDependent;
2525
import org.apache.spark.unsafe.bitset.BitSetMethods;
26+
import org.apache.spark.unsafe.string.UTF8StringMethods;
2627
import scala.collection.Map;
2728
import scala.collection.Seq;
2829

@@ -62,12 +63,16 @@ private long getFieldOffset(int ordinal) {
6263
return baseOffset + bitSetWidthInBytes + ordinal * 8;
6364
}
6465

66+
public static int calculateBitSetWidthInBytes(int numFields) {
67+
return ((numFields / 64) + ((numFields % 64 == 0 ? 0 : 1))) * 8;
68+
}
69+
6570
public UnsafeRow() { }
6671

6772
public void set(Object baseObject, long baseOffset, int numFields, StructType schema) {
6873
assert numFields >= 0 : "numFields should >= 0";
6974
assert schema == null || schema.fields().length == numFields;
70-
this.bitSetWidthInBytes = ((numFields / 64) + ((numFields % 64 == 0 ? 0 : 1))) * 8;
75+
this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
7176
this.baseObject = baseObject;
7277
this.baseOffset = baseOffset;
7378
this.numFields = numFields;
@@ -219,9 +224,11 @@ public double getDouble(int i) {
219224
@Override
220225
public String getString(int i) {
221226
assertIndexIsValid(i);
222-
// TODO
223-
224-
throw new UnsupportedOperationException();
227+
final long offsetToStringSize = getLong(i);
228+
final long stringSizeInBytes =
229+
PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize);
230+
// TODO: ugly cast; figure out whether we'll support mega long strings
231+
return UTF8StringMethods.toJavaString(baseObject, baseOffset + offsetToStringSize + 8, (int) stringSizeInBytes);
225232
}
226233

227234
@Override
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.sql.types._
21+
import org.apache.spark.unsafe.PlatformDependent
22+
import org.apache.spark.unsafe.array.ByteArrayMethods
23+
24+
/** Write a column into an UnsafeRow */
25+
/** Write a column into an UnsafeRow */
private abstract class UnsafeColumnWriter[T] {
  /**
   * Write a value into an UnsafeRow.
   *
   * @param value the value to write
   * @param columnNumber what column to write it to
   * @param row a pointer to the unsafe row
   * @param baseObject base object of the memory region that backs the row (presumably null
   *                   for off-heap rows — TODO confirm)
   * @param baseOffset offset into `baseObject` at which the row's data begins
   * @param appendCursor the offset from the start of the unsafe row to the end of the row;
   *                     used for calculating where variable-length data should be written
   * @return the number of variable-length bytes written
   */
  def write(
      value: T,
      columnNumber: Int,
      row: UnsafeRow,
      baseObject: Object,
      baseOffset: Long,
      appendCursor: Int): Int

  /**
   * Return the number of bytes that are needed to write this variable-length value.
   */
  def getSize(value: T): Int
}
51+
52+
private object UnsafeColumnWriter {
  /**
   * Return the writer for the given column type.
   *
   * Only a small set of types is supported while this converter is being prototyped;
   * anything else fails fast with a descriptive exception.
   */
  def forType(dataType: DataType): UnsafeColumnWriter[_] = {
    dataType match {
      case IntegerType => IntUnsafeColumnWriter
      case LongType => LongUnsafeColumnWriter
      case StringType => StringUnsafeColumnWriter
      case other =>
        // Name the offending type so callers can tell which column is unsupported.
        throw new UnsupportedOperationException(
          s"Do not know how to write columns of type $other to UnsafeRows")
    }
  }
}
62+
63+
private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] {
  /**
   * Size of the variable-length region for a string: an 8-byte length word plus the
   * string's bytes, rounded up to a whole 8-byte word.
   */
  def getSize(value: UTF8String): Int = {
    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.getBytes.length)
  }

  override def write(
      value: UTF8String,
      columnNumber: Int,
      row: UnsafeRow,
      baseObject: Object,
      baseOffset: Long,
      appendCursor: Int): Int = {
    // Materialize the bytes once instead of re-encoding on every access.
    val stringBytes = value.getBytes
    val numBytes = stringBytes.length
    // Variable-length layout: length word first, then the raw bytes.
    PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
    PlatformDependent.copyMemory(
      stringBytes,
      PlatformDependent.BYTE_ARRAY_OFFSET,
      baseObject,
      baseOffset + appendCursor + 8,
      numBytes
    )
    // The fixed-width slot stores the offset of the variable-length region.
    row.setLong(columnNumber, appendCursor)
    // Reuse the same rounding helper as getSize() so the two can never disagree.
    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
  }
}
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
91+
92+
/**
 * Base class for writers of fixed-width primitive columns, which occupy only their 8-byte
 * slot in the fixed-length region and therefore need zero variable-length bytes.
 */
private abstract class PrimitiveUnsafeColumnWriter[T] extends UnsafeColumnWriter[T] {
  def getSize(value: T): Int = 0
}
95+
96+
/** Writes an Int directly into its fixed-width slot; no variable-length data is appended. */
private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Int] {
  override def write(
      value: Int,
      columnNumber: Int,
      row: UnsafeRow,
      baseObject: Object,
      baseOffset: Long,
      appendCursor: Int): Int = {
    row.setInt(columnNumber, value)
    // Zero variable-length bytes written.
    0
  }
}
private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
109+
110+
/** Writes a Long directly into its fixed-width slot; no variable-length data is appended. */
private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Long] {
  override def write(
      value: Long,
      columnNumber: Int,
      row: UnsafeRow,
      baseObject: Object,
      baseOffset: Long,
      appendCursor: Int): Int = {
    row.setLong(columnNumber, value)
    // Zero variable-length bytes written.
    0
  }
}
// Plain object (not case object) for consistency with Int/String writer companions;
// the case-class machinery is unnecessary for a private singleton.
private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
123+
124+
125+
/**
 * Converts Rows into UnsafeRow format. The row layout is a fixed-length region (null bit set
 * plus one 8-byte slot per column) followed by variable-length data for columns that need it.
 */
class UnsafeRowConverter(fieldTypes: Array[DataType]) {

  // One writer per column, resolved once up front from the column's data type.
  private[this] val writers: Array[UnsafeColumnWriter[Any]] =
    fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]])

  // Size of the fixed-length region: null bit set plus one 8-byte slot per column.
  private[this] def fixedLengthSizeInBytes: Int =
    (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length)

  /** Total number of bytes needed to hold `row` once converted to the unsafe format. */
  def getSizeRequirement(row: Row): Int = {
    var variableLengthSize = 0
    var col = 0
    while (col < writers.length) {
      if (!row.isNullAt(col)) {
        // Null columns contribute no variable-length data.
        variableLengthSize += writers(col).getSize(row.get(col))
      }
      col += 1
    }
    fixedLengthSizeInBytes + variableLengthSize
  }

  /**
   * Write `row` in unsafe format into the given memory region and return the number of
   * bytes written.
   */
  def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = {
    val target = new UnsafeRow()
    target.set(baseObject, baseOffset, writers.length, null) // TODO: schema?
    // Variable-length data is appended after the fixed-length region.
    var cursor: Int = fixedLengthSizeInBytes
    var col = 0
    while (col < writers.length) {
      if (row.isNullAt(col)) {
        target.setNullAt(col)
        // TODO: type-specific null value writing?
      } else {
        cursor += writers(col).write(
          row.get(col),
          col,
          target,
          baseObject,
          baseOffset,
          cursor)
      }
      col += 1
    }
    cursor
  }

}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.sql.types.{StringType, DataType, LongType, IntegerType}
21+
import org.apache.spark.unsafe.PlatformDependent
22+
import org.apache.spark.unsafe.array.ByteArrayMethods
23+
import org.scalatest.{FunSuite, Matchers}
24+
25+
26+
class UnsafeRowConverterSuite extends FunSuite with Matchers {

  test("basic conversion with only primitive types") {
    val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
    val converter = new UnsafeRowConverter(fieldTypes)

    val inputRow = new SpecificMutableRow(fieldTypes)
    inputRow.setLong(0, 0)
    inputRow.setLong(1, 1)
    inputRow.setInt(2, 2)

    // 8 bytes of null bit set plus three 8-byte slots; no variable-length data.
    val sizeRequired: Int = converter.getSizeRequirement(inputRow)
    sizeRequired should be (8 + (3 * 8))

    val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
    val bytesWritten = converter.writeRow(inputRow, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
    bytesWritten should be (sizeRequired)

    // Read the values back through an UnsafeRow pointed at the same buffer.
    val readBack = new UnsafeRow()
    readBack.set(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
    readBack.getLong(0) should be (0)
    readBack.getLong(1) should be (1)
    readBack.getInt(2) should be (2)
  }

  test("basic conversion with primitive and string types") {
    val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
    val converter = new UnsafeRowConverter(fieldTypes)

    val inputRow = new SpecificMutableRow(fieldTypes)
    inputRow.setLong(0, 0)
    inputRow.setString(1, "Hello")
    inputRow.setString(2, "World")

    // Fixed-length region plus each string's word-rounded variable-length region.
    val sizeRequired: Int = converter.getSizeRequirement(inputRow)
    sizeRequired should be (8 + (8 * 3) +
      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
      ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8))

    val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
    val bytesWritten = converter.writeRow(inputRow, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
    bytesWritten should be (sizeRequired)

    // Read the values back through an UnsafeRow pointed at the same buffer.
    val readBack = new UnsafeRow()
    readBack.set(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
    readBack.getLong(0) should be (0)
    readBack.getString(1) should be ("Hello")
    readBack.getString(2) should be ("World")
  }
}

unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ private ByteArrayMethods() {
3131
// Private constructor, since this class only contains static methods.
3232
}
3333

34+
public static int roundNumberOfBytesToNearestWord(int numBytes) {
35+
int remainder = numBytes % 8;
36+
if (remainder == 0) {
37+
return numBytes;
38+
} else {
39+
return numBytes + (8 - remainder);
40+
}
41+
}
42+
3443
public static void zeroBytes(
3544
Object baseObject,
3645
long baseOffset,

0 commit comments

Comments
 (0)