diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 04fc98659b64..4ff0838ac117 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
@@ -30,6 +34,8 @@
 import org.apache.spark.unsafe.types.CalendarInterval;
 import org.apache.spark.unsafe.types.UTF8String;
 
+import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
+
 /**
  * An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
  *
@@ -52,7 +58,7 @@
  * Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
  */
 
-public final class UnsafeArrayData extends ArrayData {
+public final class UnsafeArrayData extends ArrayData implements Externalizable {
   public static int calculateHeaderPortionInBytes(int numFields) {
     return (int)calculateHeaderPortionInBytes((long)numFields);
   }
@@ -485,4 +491,35 @@ public static UnsafeArrayData fromPrimitiveArray(float[] arr) {
   public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
     return fromPrimitiveArray(arr, Platform.DOUBLE_ARRAY_OFFSET, arr.length, 8);
   }
+
+
+  public byte[] getBytes() {
+    if (baseObject instanceof byte[]
+        && baseOffset == Platform.BYTE_ARRAY_OFFSET
+        && (((byte[]) baseObject).length == sizeInBytes)) {
+      return (byte[]) baseObject;
+    } else {
+      byte[] bytes = new byte[sizeInBytes];
+      Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
+      return bytes;
+    }
+  }
+
+  @Override
+  public void writeExternal(ObjectOutput out) throws IOException {
+    byte[] bytes = getBytes();
+    out.writeInt(bytes.length);
+    out.writeInt(this.numElements);
+    out.write(bytes);
+  }
+
+  @Override
+  public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+    this.baseOffset = BYTE_ARRAY_OFFSET;
+    this.sizeInBytes = in.readInt();
+    this.numElements = in.readInt();
+    this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements);
+    this.baseObject = new byte[sizeInBytes];
+    in.readFully((byte[]) baseObject);
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
index 29a2bcd643f1..db25d2fb97e5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.catalyst.util
 
 import java.time.ZoneId
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class UnsafeArraySuite extends SparkFunSuite {
@@ -210,4 +212,17 @@ class UnsafeArraySuite extends SparkFunSuite {
     val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
     assert(doubleEncoder.toRow(doubleArray).getArray(0).toDoubleArray.sameElements(doubleArray))
   }
+
+  test("unsafe java serialization") {
+    val offset = 32
+    val data = new Array[Byte](1024)
+    Platform.putLong(data, offset, 1)
+    val arrayData = new UnsafeArrayData()
+    arrayData.pointTo(data, offset, data.length)
+    arrayData.setLong(0, 19285)
+    val ser = new JavaSerializer(new SparkConf).newInstance()
+    val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(arrayData))
+    assert(arrayDataSer.getLong(0) == 19285)
+    assert(arrayDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024)
+  }
 }
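For anyone who wants to try the change outside the test suite, a minimal round-trip sketch follows. It drives the new writeExternal/readExternal hooks through plain ObjectOutputStream/ObjectInputStream, which is the machinery JavaSerializer ultimately delegates to. The class name UnsafeArrayDataRoundTrip and the sample values are invented for illustration and are not part of the patch.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
import org.apache.spark.unsafe.Platform;

public class UnsafeArrayDataRoundTrip {
  public static void main(String[] args) throws Exception {
    // Build a one-element long array in the Unsafe format:
    // [numElements: 8 bytes][null bits: 8 bytes][element 0: 8 bytes]
    byte[] backing = new byte[64];
    Platform.putLong(backing, Platform.BYTE_ARRAY_OFFSET, 1L); // numElements = 1

    UnsafeArrayData array = new UnsafeArrayData();
    array.pointTo(backing, Platform.BYTE_ARRAY_OFFSET, backing.length);
    array.setLong(0, 42L);

    // Serialize: ObjectOutputStream invokes writeExternal for Externalizable types.
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
      oos.writeObject(array);
    }

    // Deserialize: the no-arg constructor runs, then readExternal repopulates
    // baseObject/baseOffset/sizeInBytes/numElements from the stream.
    try (ObjectInputStream ois =
        new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
      UnsafeArrayData copy = (UnsafeArrayData) ois.readObject();
      if (copy.getLong(0) != 42L) {
        throw new AssertionError("round trip lost element 0");
      }
    }
  }
}

Note the serialized form always re-bases the array onto a fresh byte[] at BYTE_ARRAY_OFFSET on the read side, regardless of what baseObject/baseOffset pointed at before serialization; getBytes() handles the copy when the original data was offset into a larger buffer, as in the test above.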