Skip to content

Commit 54b0d1e

Browse files
pengbocloud-fan
authored andcommitted
[SPARK-27416][SQL] UnsafeMapData & UnsafeArrayData Kryo serialization …
## What changes were proposed in this pull request? Finish the rest work of #24317, #9030 a. Implement Kryo serialization for UnsafeArrayData b. fix UnsafeMapData Java/Kryo Serialization issue when two machines have different Oops size c. Move the duplicate code "getBytes()" to Utils. ## How was this patch tested? According Units has been added & tested Closes #24357 from pengbo/SPARK-27416_new. Authored-by: pengbo <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 61feb16 commit 54b0d1e

File tree

7 files changed

+197
-32
lines changed

7 files changed

+197
-32
lines changed

core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,10 @@ class KryoSerializer(conf: SparkConf)
213213
// We can't load those class directly in order to avoid unnecessary jar dependencies.
214214
// We load them safely, ignore it if the class not found.
215215
Seq(
216+
"org.apache.spark.sql.catalyst.expressions.UnsafeRow",
217+
"org.apache.spark.sql.catalyst.expressions.UnsafeArrayData",
218+
"org.apache.spark.sql.catalyst.expressions.UnsafeMapData",
219+
216220
"org.apache.spark.ml.attribute.Attribute",
217221
"org.apache.spark.ml.attribute.AttributeGroup",
218222
"org.apache.spark.ml.attribute.BinaryAttribute",

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
import java.math.BigInteger;
2626
import java.nio.ByteBuffer;
2727

28+
import com.esotericsoftware.kryo.Kryo;
29+
import com.esotericsoftware.kryo.KryoSerializable;
30+
import com.esotericsoftware.kryo.io.Input;
31+
import com.esotericsoftware.kryo.io.Output;
32+
2833
import org.apache.spark.sql.catalyst.util.ArrayData;
2934
import org.apache.spark.sql.types.*;
3035
import org.apache.spark.unsafe.Platform;
@@ -58,7 +63,7 @@
5863
* Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
5964
*/
6065

61-
public final class UnsafeArrayData extends ArrayData implements Externalizable {
66+
public final class UnsafeArrayData extends ArrayData implements Externalizable, KryoSerializable {
6267
public static int calculateHeaderPortionInBytes(int numFields) {
6368
return (int)calculateHeaderPortionInBytes((long)numFields);
6469
}
@@ -492,22 +497,9 @@ public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
492497
return fromPrimitiveArray(arr, Platform.DOUBLE_ARRAY_OFFSET, arr.length, 8);
493498
}
494499

495-
496-
public byte[] getBytes() {
497-
if (baseObject instanceof byte[]
498-
&& baseOffset == Platform.BYTE_ARRAY_OFFSET
499-
&& (((byte[]) baseObject).length == sizeInBytes)) {
500-
return (byte[]) baseObject;
501-
} else {
502-
byte[] bytes = new byte[sizeInBytes];
503-
Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
504-
return bytes;
505-
}
506-
}
507-
508500
@Override
509501
public void writeExternal(ObjectOutput out) throws IOException {
510-
byte[] bytes = getBytes();
502+
byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes);
511503
out.writeInt(bytes.length);
512504
out.writeInt(this.numElements);
513505
out.write(bytes);
@@ -522,4 +514,22 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept
522514
this.baseObject = new byte[sizeInBytes];
523515
in.readFully((byte[]) baseObject);
524516
}
517+
518+
@Override
519+
public void write(Kryo kryo, Output output) {
520+
byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes);
521+
output.writeInt(bytes.length);
522+
output.writeInt(this.numElements);
523+
output.write(bytes);
524+
}
525+
526+
@Override
527+
public void read(Kryo kryo, Input input) {
528+
this.baseOffset = BYTE_ARRAY_OFFSET;
529+
this.sizeInBytes = input.readInt();
530+
this.numElements = input.readInt();
531+
this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements);
532+
this.baseObject = new byte[sizeInBytes];
533+
input.read((byte[]) baseObject);
534+
}
525535
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.catalyst.expressions;
18+
19+
import org.apache.spark.unsafe.Platform;
20+
21+
/**
22+
* General utilities available for unsafe data
23+
*/
24+
final class UnsafeDataUtils {
25+
26+
private UnsafeDataUtils() {
27+
}
28+
29+
public static byte[] getBytes(Object baseObject, long baseOffset, int sizeInBytes) {
30+
if (baseObject instanceof byte[]
31+
&& baseOffset == Platform.BYTE_ARRAY_OFFSET
32+
&& (((byte[]) baseObject).length == sizeInBytes)) {
33+
return (byte[]) baseObject;
34+
}
35+
byte[] bytes = new byte[sizeInBytes];
36+
Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET,
37+
sizeInBytes);
38+
return bytes;
39+
}
40+
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,22 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions;
1919

20+
import java.io.Externalizable;
21+
import java.io.IOException;
22+
import java.io.ObjectInput;
23+
import java.io.ObjectOutput;
2024
import java.nio.ByteBuffer;
2125

26+
import com.esotericsoftware.kryo.Kryo;
27+
import com.esotericsoftware.kryo.KryoSerializable;
28+
import com.esotericsoftware.kryo.io.Input;
29+
import com.esotericsoftware.kryo.io.Output;
30+
2231
import org.apache.spark.sql.catalyst.util.MapData;
2332
import org.apache.spark.unsafe.Platform;
2433

34+
import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
35+
2536
/**
2637
* An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
2738
*
@@ -33,7 +44,7 @@
3344
* elements, otherwise the behavior is undefined.
3445
*/
3546
// TODO: Use a more efficient format which doesn't depend on unsafe array.
36-
public final class UnsafeMapData extends MapData {
47+
public final class UnsafeMapData extends MapData implements Externalizable, KryoSerializable {
3748

3849
private Object baseObject;
3950
private long baseOffset;
@@ -123,4 +134,36 @@ public UnsafeMapData copy() {
123134
mapCopy.pointTo(mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
124135
return mapCopy;
125136
}
137+
138+
@Override
139+
public void writeExternal(ObjectOutput out) throws IOException {
140+
byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes);
141+
out.writeInt(bytes.length);
142+
out.write(bytes);
143+
}
144+
145+
@Override
146+
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
147+
this.baseOffset = BYTE_ARRAY_OFFSET;
148+
this.sizeInBytes = in.readInt();
149+
this.baseObject = new byte[sizeInBytes];
150+
in.readFully((byte[]) baseObject);
151+
pointTo(baseObject, baseOffset, sizeInBytes);
152+
}
153+
154+
@Override
155+
public void write(Kryo kryo, Output output) {
156+
byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes);
157+
output.writeInt(bytes.length);
158+
output.write(bytes);
159+
}
160+
161+
@Override
162+
public void read(Kryo kryo, Input input) {
163+
this.baseOffset = BYTE_ARRAY_OFFSET;
164+
this.sizeInBytes = input.readInt();
165+
this.baseObject = new byte[sizeInBytes];
166+
input.read((byte[]) baseObject);
167+
pointTo(baseObject, baseOffset, sizeInBytes);
168+
}
126169
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -541,14 +541,7 @@ public boolean equals(Object other) {
541541
* Returns the underlying bytes for this UnsafeRow.
542542
*/
543543
public byte[] getBytes() {
544-
if (baseObject instanceof byte[] && baseOffset == Platform.BYTE_ARRAY_OFFSET
545-
&& (((byte[]) baseObject).length == sizeInBytes)) {
546-
return (byte[]) baseObject;
547-
} else {
548-
byte[] bytes = new byte[sizeInBytes];
549-
Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
550-
return bytes;
551-
}
544+
return UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes);
552545
}
553546

554547
// This is for debugging

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util
2020
import java.time.ZoneId
2121

2222
import org.apache.spark.{SparkConf, SparkFunSuite}
23-
import org.apache.spark.serializer.JavaSerializer
23+
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
2424
import org.apache.spark.sql.Row
2525
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
2626
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
@@ -60,6 +60,16 @@ class UnsafeArraySuite extends SparkFunSuite {
6060
val doubleMultiDimArray = Array(
6161
Array(1.1, 11.1), Array(2.2, 22.2, 222.2), Array(3.3, 33.3, 333.3, 3333.3))
6262

63+
val serialArray = {
64+
val offset = 32
65+
val data = new Array[Byte](1024)
66+
Platform.putLong(data, offset, 1)
67+
val arrayData = new UnsafeArrayData()
68+
arrayData.pointTo(data, offset, data.length)
69+
arrayData.setLong(0, 19285)
70+
arrayData
71+
}
72+
6373
test("read array") {
6474
val unsafeBoolean = ExpressionEncoder[Array[Boolean]].resolveAndBind().
6575
toRow(booleanArray).getArray(0)
@@ -214,14 +224,15 @@ class UnsafeArraySuite extends SparkFunSuite {
214224
}
215225

216226
test("unsafe java serialization") {
217-
val offset = 32
218-
val data = new Array[Byte](1024)
219-
Platform.putLong(data, offset, 1)
220-
val arrayData = new UnsafeArrayData()
221-
arrayData.pointTo(data, offset, data.length)
222-
arrayData.setLong(0, 19285)
223227
val ser = new JavaSerializer(new SparkConf).newInstance()
224-
val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(arrayData))
228+
val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(serialArray))
229+
assert(arrayDataSer.getLong(0) == 19285)
230+
assert(arrayDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024)
231+
}
232+
233+
test("unsafe Kryo serialization") {
234+
val ser = new KryoSerializer(new SparkConf).newInstance()
235+
val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(serialArray))
225236
assert(arrayDataSer.getLong(0) == 19285)
226237
assert(arrayDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024)
227238
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.util
19+
20+
import org.apache.spark.{SparkConf, SparkFunSuite}
21+
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
22+
import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData}
23+
import org.apache.spark.unsafe.Platform
24+
25+
class UnsafeMapSuite extends SparkFunSuite {
26+
27+
val unsafeMapData = {
28+
val offset = 32
29+
val keyArraySize = 256
30+
val baseObject = new Array[Byte](1024)
31+
Platform.putLong(baseObject, offset, keyArraySize)
32+
33+
val unsafeMap = new UnsafeMapData
34+
Platform.putLong(baseObject, offset + 8, 1)
35+
val keyArray = new UnsafeArrayData()
36+
keyArray.pointTo(baseObject, offset + 8, keyArraySize)
37+
keyArray.setLong(0, 19285)
38+
39+
val valueArray = new UnsafeArrayData()
40+
Platform.putLong(baseObject, offset + 8 + keyArray.getSizeInBytes, 1)
41+
valueArray.pointTo(baseObject, offset + 8 + keyArray.getSizeInBytes, keyArraySize)
42+
valueArray.setLong(0, 19286)
43+
unsafeMap.pointTo(baseObject, offset, baseObject.length)
44+
unsafeMap
45+
}
46+
47+
test("unsafe java serialization") {
48+
val ser = new JavaSerializer(new SparkConf).newInstance()
49+
val mapDataSer = ser.deserialize[UnsafeMapData](ser.serialize(unsafeMapData))
50+
assert(mapDataSer.numElements() == 1)
51+
assert(mapDataSer.keyArray().getInt(0) == 19285)
52+
assert(mapDataSer.valueArray().getInt(0) == 19286)
53+
assert(mapDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024)
54+
}
55+
56+
test("unsafe Kryo serialization") {
57+
val ser = new KryoSerializer(new SparkConf).newInstance()
58+
val mapDataSer = ser.deserialize[UnsafeMapData](ser.serialize(unsafeMapData))
59+
assert(mapDataSer.numElements() == 1)
60+
assert(mapDataSer.keyArray().getInt(0) == 19285)
61+
assert(mapDataSer.valueArray().getInt(0) == 19286)
62+
assert(mapDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024)
63+
}
64+
}

0 commit comments

Comments
 (0)