diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 424a3ed9bb5b0..7271004a68859 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -182,6 +182,10 @@ class KryoSerializer(conf: SparkConf) // We can't load those class directly in order to avoid unnecessary jar dependencies. // We load them safely, ignore it if the class not found. Seq( + "org.apache.spark.sql.catalyst.expressions.UnsafeRow", + "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData", + "org.apache.spark.sql.catalyst.expressions.UnsafeMapData", + "org.apache.spark.ml.feature.Instance", "org.apache.spark.ml.feature.LabeledPoint", "org.apache.spark.ml.feature.OffsetInstance", diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 8e4ecf3f910a3..d3a8da4f3145b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -25,6 +25,11 @@ import java.math.BigInteger; import java.nio.ByteBuffer; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; @@ -58,7 +63,7 @@ * Instances of `UnsafeArrayData` act as pointers to row data stored in this format. 
*/ -public final class UnsafeArrayData extends ArrayData implements Externalizable { +public final class UnsafeArrayData extends ArrayData implements Externalizable, KryoSerializable { public static int calculateHeaderPortionInBytes(int numFields) { return (int)calculateHeaderPortionInBytes((long)numFields); @@ -530,22 +535,9 @@ public static UnsafeArrayData fromPrimitiveArray(double[] arr) { return fromPrimitiveArray(arr, Platform.DOUBLE_ARRAY_OFFSET, arr.length, 8); } - - public byte[] getBytes() { - if (baseObject instanceof byte[] - && baseOffset == Platform.BYTE_ARRAY_OFFSET - && (((byte[]) baseObject).length == sizeInBytes)) { - return (byte[]) baseObject; - } else { - byte[] bytes = new byte[sizeInBytes]; - Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); - return bytes; - } - } - @Override public void writeExternal(ObjectOutput out) throws IOException { - byte[] bytes = getBytes(); + byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes); out.writeInt(bytes.length); out.writeInt(this.numElements); out.write(bytes); @@ -560,4 +552,22 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept this.baseObject = new byte[sizeInBytes]; in.readFully((byte[]) baseObject); } + + @Override + public void write(Kryo kryo, Output output) { + byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes); + output.writeInt(bytes.length); + output.writeInt(this.numElements); + output.write(bytes); + } + + @Override + public void read(Kryo kryo, Input input) { + this.baseOffset = BYTE_ARRAY_OFFSET; + this.sizeInBytes = input.readInt(); + this.numElements = input.readInt(); + this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements); + this.baseObject = new byte[sizeInBytes]; + input.readBytes((byte[]) baseObject); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeDataUtils.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeDataUtils.java new file mode 100644 index 0000000000000..9b600192ac250 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeDataUtils.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.unsafe.Platform; + +/** + * General utilities available for unsafe data + */ +final class UnsafeDataUtils { + + private UnsafeDataUtils() { + } + + public static byte[] getBytes(Object baseObject, long baseOffset, int sizeInBytes) { + if (baseObject instanceof byte[] + && baseOffset == Platform.BYTE_ARRAY_OFFSET + && (((byte[]) baseObject).length == sizeInBytes)) { + return (byte[]) baseObject; + } + byte[] bytes = new byte[sizeInBytes]; + Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, + sizeInBytes); + return bytes; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index f17441dfccb6d..e07ac2fcf81b9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -17,11 +17,22 @@ package org.apache.spark.sql.catalyst.expressions; +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.nio.ByteBuffer; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.unsafe.Platform; +import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; + /** * An Unsafe implementation of Map which is backed by raw memory instead of Java objects. * @@ -30,7 +41,7 @@ * [unsafe key array numBytes] [unsafe key array] [unsafe value array] */ // TODO: Use a more efficient format which doesn't depend on unsafe array. 
-public final class UnsafeMapData extends MapData { +public final class UnsafeMapData extends MapData implements Externalizable, KryoSerializable { private Object baseObject; private long baseOffset; @@ -120,4 +131,36 @@ public UnsafeMapData copy() { mapCopy.pointTo(mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); return mapCopy; } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes); + out.writeInt(bytes.length); + out.write(bytes); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + this.baseOffset = BYTE_ARRAY_OFFSET; + this.sizeInBytes = in.readInt(); + this.baseObject = new byte[sizeInBytes]; + in.readFully((byte[]) baseObject); + pointTo(baseObject, baseOffset, sizeInBytes); + } + + @Override + public void write(Kryo kryo, Output output) { + byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes); + output.writeInt(bytes.length); + output.write(bytes); + } + + @Override + public void read(Kryo kryo, Input input) { + this.baseOffset = BYTE_ARRAY_OFFSET; + this.sizeInBytes = input.readInt(); + this.baseObject = new byte[sizeInBytes]; + input.readBytes((byte[]) baseObject); + pointTo(baseObject, baseOffset, sizeInBytes); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index a76e6ef8c91c1..ee2b67a2d1885 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -586,14 +586,7 @@ public boolean equals(Object other) { * Returns the underlying bytes for this UnsafeRow. 
*/ public byte[] getBytes() { - if (baseObject instanceof byte[] && baseOffset == Platform.BYTE_ARRAY_OFFSET - && (((byte[]) baseObject).length == sizeInBytes)) { - return (byte[]) baseObject; - } else { - byte[] bytes = new byte[sizeInBytes]; - Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); - return bytes; - } + return UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes); } // This is for debugging diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala index 818b2bdfce93e..f485320b395cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData @@ -54,6 +54,16 @@ class UnsafeArraySuite extends SparkFunSuite { val doubleMultiDimArray = Array( Array(1.1, 11.1), Array(2.2, 22.2, 222.2), Array(3.3, 33.3, 333.3, 3333.3)) + val serialArray = { + val offset = 32 + val data = new Array[Byte](1024) + Platform.putLong(data, offset, 1) + val arrayData = new UnsafeArrayData() + arrayData.pointTo(data, offset, data.length) + arrayData.setLong(0, 19285) + arrayData + } + test("read array") { val unsafeBoolean = ExpressionEncoder[Array[Boolean]].resolveAndBind(). 
toRow(booleanArray).getArray(0) @@ -208,14 +218,15 @@ class UnsafeArraySuite extends SparkFunSuite { } test("unsafe java serialization") { - val offset = 32 - val data = new Array[Byte](1024) - Platform.putLong(data, offset, 1) - val arrayData = new UnsafeArrayData() - arrayData.pointTo(data, offset, data.length) - arrayData.setLong(0, 19285) val ser = new JavaSerializer(new SparkConf).newInstance() - val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(arrayData)) + val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(serialArray)) + assert(arrayDataSer.getLong(0) == 19285) + assert(arrayDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024) + } + + test("unsafe Kryo serialization") { + val ser = new KryoSerializer(new SparkConf).newInstance() + val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(serialArray)) assert(arrayDataSer.getLong(0) == 19285) assert(arrayDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeMapSuite.scala new file mode 100644 index 0000000000000..ebc88612be22a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeMapSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData} +import org.apache.spark.unsafe.Platform + +class UnsafeMapSuite extends SparkFunSuite { + + val unsafeMapData = { + val offset = 32 + val keyArraySize = 256 + val baseObject = new Array[Byte](1024) + Platform.putLong(baseObject, offset, keyArraySize) + + val unsafeMap = new UnsafeMapData + Platform.putLong(baseObject, offset + 8, 1) + val keyArray = new UnsafeArrayData() + keyArray.pointTo(baseObject, offset + 8, keyArraySize) + keyArray.setLong(0, 19285) + + val valueArray = new UnsafeArrayData() + Platform.putLong(baseObject, offset + 8 + keyArray.getSizeInBytes, 1) + valueArray.pointTo(baseObject, offset + 8 + keyArray.getSizeInBytes, keyArraySize) + valueArray.setLong(0, 19286) + unsafeMap.pointTo(baseObject, offset, baseObject.length) + unsafeMap + } + + test("unsafe java serialization") { + val ser = new JavaSerializer(new SparkConf).newInstance() + val mapDataSer = ser.deserialize[UnsafeMapData](ser.serialize(unsafeMapData)) + assert(mapDataSer.numElements() == 1) + assert(mapDataSer.keyArray().getInt(0) == 19285) + assert(mapDataSer.valueArray().getInt(0) == 19286) + assert(mapDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024) + } + + test("unsafe Kryo serialization") { + val ser = new KryoSerializer(new SparkConf).newInstance() + val mapDataSer = 
ser.deserialize[UnsafeMapData](ser.serialize(unsafeMapData)) + assert(mapDataSer.numElements() == 1) + assert(mapDataSer.keyArray().getInt(0) == 19285) + assert(mapDataSer.valueArray().getInt(0) == 19286) + assert(mapDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024) + } +}