UnsafeRow.java
@@ -17,15 +17,19 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
-import java.io.IOException;
-import java.io.OutputStream;
+import java.io.*;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.Set;
 
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -35,6 +39,7 @@
 import org.apache.spark.unsafe.types.UTF8String;
 
 import static org.apache.spark.sql.types.DataTypes.*;
+import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
 
 /**
  * An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
@@ -52,7 +57,7 @@
  *
  * Instances of `UnsafeRow` act as pointers to row data stored in this format.
  */
-public final class UnsafeRow extends MutableRow {
+public final class UnsafeRow extends MutableRow implements Externalizable, KryoSerializable {
 
   //////////////////////////////////////////////////////////////////////////////
   // Static methods
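
(For context: java.io.Externalizable requires a public no-argument constructor plus the writeExternal/readExternal pair, while Kryo's KryoSerializable requires write(Kryo, Output) and read(Kryo, Input); the hunk below adds all four methods.)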
@@ -596,4 +601,40 @@ public boolean anyNull() {
   public void writeToMemory(Object target, long targetOffset) {
     Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
   }
+
+  @Override
+  public void writeExternal(ObjectOutput out) throws IOException {
+    byte[] bytes = getBytes();
+    out.writeInt(bytes.length);
+    out.writeInt(this.numFields);
+    out.write(bytes);
+  }
+
+  @Override
+  public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+    this.baseOffset = BYTE_ARRAY_OFFSET;
+    this.sizeInBytes = in.readInt();
+    this.numFields = in.readInt();
+    this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
+    this.baseObject = new byte[sizeInBytes];
+    in.readFully((byte[]) baseObject);
+  }
+
+  @Override
+  public void write(Kryo kryo, Output out) {
+    byte[] bytes = getBytes();
+    out.writeInt(bytes.length);
+    out.writeInt(this.numFields);
+    out.write(bytes);
+  }
+
+  @Override
+  public void read(Kryo kryo, Input in) {
+    this.baseOffset = BYTE_ARRAY_OFFSET;
+    this.sizeInBytes = in.readInt();
+    this.numFields = in.readInt();
+    this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
+    this.baseObject = new byte[sizeInBytes];
+    in.read((byte[]) baseObject);
+  }
 }

Contributor, inline on out.write(bytes) in write(Kryo, Output): can this just call writeExternal(out)?
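
As written, it cannot: writeExternal takes a java.io.ObjectOutput, while Kryo hands write() a com.esotericsoftware.kryo.io.Output, which is an OutputStream but does not implement ObjectOutput, so writeExternal(out) would not compile there. A minimal sketch of one way to deduplicate the two write paths, assuming the surrounding UnsafeRow class; the helper name writeHeaderAndBytes is hypothetical and not part of this PR:

import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.ObjectOutput;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Output;

// Inside UnsafeRow:

// Writes the wire format used above: payload length, field count, row bytes.
private void writeHeaderAndBytes(DataOutput out) throws IOException {
  byte[] bytes = getBytes();
  out.writeInt(bytes.length);
  out.writeInt(numFields);
  out.write(bytes);
}

@Override
public void writeExternal(ObjectOutput out) throws IOException {
  writeHeaderAndBytes(out);  // ObjectOutput already extends DataOutput
}

@Override
public void write(Kryo kryo, Output out) {
  try {
    // Kryo's Output is an OutputStream, so wrap it in a DataOutputStream.
    writeHeaderAndBytes(new DataOutputStream(out));
  } catch (IOException e) {
    throw new RuntimeException(e);  // write(Kryo, Output) declares no checked exceptions
  }
}

The read paths could be shared the same way (Kryo's Input is an InputStream, so it wraps in a DataInputStream), where DataInput.readFully would also guarantee the buffer is completely filled, which a bare InputStream-style read(byte[]) does not.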
UnsafeRowSuite.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql
 
 import java.io.ByteArrayOutputStream
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.{KryoSerializer, JavaSerializer}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
 import org.apache.spark.sql.types._
@@ -29,6 +30,32 @@ import org.apache.spark.unsafe.types.UTF8String
 
 class UnsafeRowSuite extends SparkFunSuite {
 
+  test("UnsafeRow Java serialization") {
+    // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data
+    val data = new Array[Byte](1024)
+    val row = new UnsafeRow
+    row.pointTo(data, 1, 16)
+    row.setLong(0, 19285)
+
+    val ser = new JavaSerializer(new SparkConf).newInstance()
+    val row1 = ser.deserialize[UnsafeRow](ser.serialize(row))
+    assert(row1.getLong(0) == 19285)
+    assert(row1.getBaseObject().asInstanceOf[Array[Byte]].length == 16)
+  }
+
+  test("UnsafeRow Kryo serialization") {
+    // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data
+    val data = new Array[Byte](1024)
+    val row = new UnsafeRow
+    row.pointTo(data, 1, 16)
+    row.setLong(0, 19285)
+
+    val ser = new KryoSerializer(new SparkConf).newInstance()
+    val row1 = ser.deserialize[UnsafeRow](ser.serialize(row))
+    assert(row1.getLong(0) == 19285)
+    assert(row1.getBaseObject().asInstanceOf[Array[Byte]].length == 16)
+  }
+
   test("bitset width calculation") {
     assert(UnsafeRow.calculateBitSetWidthInBytes(0) === 0)
     assert(UnsafeRow.calculateBitSetWidthInBytes(1) === 8)

Contributor, inline on the == assertions in the Kryo test: ===

Contributor Author: This actually doesn't matter anymore with new versions of ScalaTest.
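
For reference, the Java-serialization test above can be reproduced outside ScalaTest with a plain java.io round trip. This is a sketch assuming the UnsafeRow API visible in this diff (the no-argument constructor and the three-argument pointTo(buffer, numFields, sizeInBytes)); the class name UnsafeRowRoundTrip is hypothetical:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

import org.apache.spark.sql.catalyst.expressions.UnsafeRow;

public class UnsafeRowRoundTrip {
  public static void main(String[] args) throws Exception {
    byte[] data = new byte[1024];  // large backing buffer; the row uses only 16 bytes
    UnsafeRow row = new UnsafeRow();
    row.pointTo(data, 1, 16);      // one field: 8-byte null bitset + 8-byte long
    row.setLong(0, 19285);

    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
      oos.writeObject(row);        // Externalizable dispatches to writeExternal
    }
    try (ObjectInputStream ois =
        new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
      UnsafeRow copy = (UnsafeRow) ois.readObject();  // calls readExternal
      // Only the 16 relevant bytes were serialized, not the whole 1 KB buffer.
      System.out.println(copy.getLong(0));                        // 19285
      System.out.println(((byte[]) copy.getBaseObject()).length); // 16
    }
  }
}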