UnsafeFixedWidthAggregationMap.java
@@ -17,7 +17,6 @@

 package org.apache.spark.sql.catalyst.expressions;

-import java.util.Arrays;
 import java.util.Iterator;

 import org.apache.spark.sql.catalyst.InternalRow;
@@ -142,14 +141,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
     final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
     // Make sure that the buffer is large enough to hold the key. If it's not, grow it:
     if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
-      // This new array will be initially zero, so there's no need to zero it out here
       groupingKeyConversionScratchSpace = new byte[groupingKeySize];
-    } else {
-      // Zero out the buffer that's used to hold the current row. This is necessary in order
-      // to ensure that rows hash properly, since garbage data from the previous row could
-      // otherwise end up as padding in this row. As a performance optimization, we only zero out
-      // the portion of the buffer that we'll actually write to.
-      Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, (byte) 0);
     }
     final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
       groupingKey,
Contributor:

I think that it's safe to remove this zeroing as long as we assume that (a) every column will actually end up writing to the row, (b) for null columns, we zero out the fixed-length section, and (c) we zero the bitset when starting to write the row. All three of these assumptions seem to hold, so this seems fine.
UnsafeRow.java
@@ -47,7 +47,8 @@
  * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length
  * primitive types, such as long, double, or int, we store the value directly in the word. For
  * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the
- * base address of the row) that points to the beginning of the variable-length field.
+ * base address of the row) that points to the beginning of the variable-length field, and its
+ * length (the offset and the length are packed into a single long).
  *
  * Instances of `UnsafeRow` act as pointers to row data stored in this format.
  */
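To make the packed encoding concrete, here is a minimal, self-contained Java sketch of the round trip; `pack`, `unpackOffset`, and `unpackSize` are illustrative names, not UnsafeRow API. It also shows why the plain shift discussed in the review further down needs no mask for non-negative offsets.

```java
public class OffsetAndSizeDemo {
    // Pack a non-negative offset and size into one 8-byte word:
    // offset in the upper 32 bits, size in the lower 32 bits.
    static long pack(int offset, int size) {
        return ((long) offset << 32) | (size & 0xFFFFFFFFL);
    }

    static int unpackOffset(long word) {
        // The arithmetic shift is safe: a non-negative offset never sets the
        // top bit, so no sign-extension can occur.
        return (int) (word >> 32);
    }

    static int unpackSize(long word) {
        return (int) (word & ((1L << 32) - 1));
    }

    public static void main(String[] args) {
        long word = pack(48, 5); // e.g. a 5-byte field stored at byte offset 48
        System.out.println(unpackOffset(word) + " " + unpackSize(word)); // 48 5
    }
}
```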
@@ -92,6 +93,7 @@ public static int calculateBitSetWidthInBytes(int numFields) {
    */
   public static final Set<DataType> readableFieldTypes;

+  // TODO: support DecimalType
   static {
     settableFieldTypes = Collections.unmodifiableSet(
       new HashSet<DataType>(
@@ -111,7 +113,8 @@ public static int calculateBitSetWidthInBytes(int numFields) {
     // We support get() on a superset of the types for which we support set():
     final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
       Arrays.asList(new DataType[]{
-        StringType
+        StringType,
+        BinaryType
       }));
     _readableFieldTypes.addAll(settableFieldTypes);
     readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
@@ -221,11 +224,6 @@ public void setFloat(int ordinal, float value) {
     PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
   }

-  @Override
-  public void setString(int ordinal, String value) {
-    throw new UnsupportedOperationException();
-  }
-
   @Override
   public int size() {
     return numFields;
@@ -249,6 +247,8 @@ public Object get(int i) {
       return null;
     } else if (dataType == StringType) {
       return getUTF8String(i);
+    } else if (dataType == BinaryType) {
+      return getBinary(i);
     } else {
       throw new UnsupportedOperationException();
     }
@@ -311,19 +311,23 @@ public double getDouble(int i) {
   }

   public UTF8String getUTF8String(int i) {
+    return UTF8String.fromBytes(getBinary(i));
+  }
+
+  public byte[] getBinary(int i) {
     assertIndexIsValid(i);
-    final long offsetToStringSize = getLong(i);
-    final int stringSizeInBytes =
-      (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize);
-    final byte[] strBytes = new byte[stringSizeInBytes];
+    final long offsetAndSize = getLong(i);
+    final int offset = (int)(offsetAndSize >> 32);

Contributor:

Do we need to mask out the upper 32 bits before converting to an int? I guess the uppermost bit probably can't be 1, because the offset can't be negative, so I guess we don't need to worry about sign-extension during the shift.

Contributor (author):

No.

+    final int size = (int)(offsetAndSize & ((1L << 32) - 1));
+    final byte[] bytes = new byte[size];
     PlatformDependent.copyMemory(
       baseObject,
-      baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data
-      strBytes,
+      baseOffset + offset,
+      bytes,
       PlatformDependent.BYTE_ARRAY_OFFSET,
-      stringSizeInBytes
+      size
     );
-    return UTF8String.fromBytes(strBytes);
+    return bytes;
   }

   @Override
UnsafeRowConverter.scala
@@ -17,8 +17,6 @@

 package org.apache.spark.sql.catalyst.expressions

-import org.apache.spark.sql.catalyst.util.DateUtils
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -72,6 +70,19 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
    */
   def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long): Int = {
     unsafeRow.pointTo(baseObject, baseOffset, writers.length, null)
+
+    if (writers.length > 0) {

Contributor:

Is it possible to declare a row which has no columns? I'm just wondering if we ever run into a case where we need to worry about writers.length == 0, and whether we might also allocate bitset space in that case due to rounding or something.

Contributor (author):

It's possible to have a Row with no columns:

>>> df.groupBy(df.id).agg({}).collect()
[Row(id=1), Row(id=2), Row(id=3), Row(id=4)]

Contributor:

I think that we need to change this to >=, then. If you look at the line below, on line 77, when writers.length == 0 then n == 0, but the loop implies that we'll still allocate one word for null-tracking bits.

Contributor:

Actually, since writers.length can't be negative, maybe we should just remove this if and unconditionally null out the bitset.

Contributor (author):

Yes, I had checked that.

+      // zero-out the bitset
+      var n = writers.length / 64
+      while (n >= 0) {
+        PlatformDependent.UNSAFE.putLong(
+          unsafeRow.getBaseObject,
+          unsafeRow.getBaseOffset + n * 8,
+          0L)
+        n -= 1
+      }
+    }
+
     var fieldNumber = 0
     var appendCursor: Int = fixedLengthSize
     while (fieldNumber < writers.length) {
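As a footnote to the thread above, a small standalone Java sketch (the helper names are mine, not Spark's) compares how many 8-byte words the countdown loop clears with how many the null bitset actually occupies; `wordsReserved` mirrors the ceiling division I would expect from `UnsafeRow.calculateBitSetWidthInBytes`.

```java
public class BitsetWordsDemo {
    // Words the writeRow loop clears: n counts down from numFields / 64 to 0.
    static int wordsZeroed(int numFields) {
        return numFields / 64 + 1;
    }

    // Words actually reserved for null-tracking bits, one bit per field.
    static int wordsReserved(int numFields) {
        return (numFields + 63) / 64;
    }

    public static void main(String[] args) {
        for (int n : new int[]{0, 1, 63, 64}) {
            System.out.printf("%d fields: zeroes %d word(s), reserves %d%n",
                n, wordsZeroed(n), wordsReserved(n));
        }
        // 0 fields: the loop body alone would clear one word even though none
        // is reserved -- the case behind the `writers.length > 0` guard.
        // 64 fields: one extra word is cleared, but it falls in the
        // fixed-length field region, which the column writers overwrite.
    }
}
```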
@@ -122,6 +133,7 @@ private object UnsafeColumnWriter {
     case FloatType => FloatUnsafeColumnWriter
     case DoubleType => DoubleUnsafeColumnWriter
     case StringType => StringUnsafeColumnWriter
+    case BinaryType => BinaryUnsafeColumnWriter
     case DateType => IntUnsafeColumnWriter
     case TimestampType => LongUnsafeColumnWriter
     case t =>
@@ -141,6 +153,7 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
 private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
 private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
 private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
+private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter

 private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
   // Primitives don't write to the variable-length region:
@@ -235,30 +248,47 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
   }
 }

-private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
+private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
+
+  def getBytes(source: InternalRow, column: Int): Array[Byte]
+
   def getSize(source: InternalRow, column: Int): Int = {
-    val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length
-    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+    val numBytes = getBytes(source, column).length
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }

   override def write(
       source: InternalRow,
       target: UnsafeRow,
       column: Int,
       appendCursor: Int): Int = {
-    val value = source.get(column).asInstanceOf[UTF8String]
-    val baseObject = target.getBaseObject
-    val baseOffset = target.getBaseOffset
-    val numBytes = value.getBytes.length
-    PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
+    val offset = target.getBaseOffset + appendCursor
+    val bytes = getBytes(source, column)
+    val numBytes = bytes.length
+    if ((numBytes & 0x07) > 0) {
+      // zero-out the padding bytes
+      PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + ((numBytes >> 3) << 3), 0L)
+    }
     PlatformDependent.copyMemory(
-      value.getBytes,
+      bytes,
       PlatformDependent.BYTE_ARRAY_OFFSET,
-      baseObject,
-      baseOffset + appendCursor + 8,
+      target.getBaseObject,
+      offset,
       numBytes
     )
-    target.setLong(column, appendCursor)
-    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+    target.setLong(column, (appendCursor.toLong << 32L) | numBytes.toLong)

Contributor:

I guess that numBytes isn't going to be larger than 2 gigabytes, so we're fine not masking it here.

+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
 }
+
+private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+  def getBytes(source: InternalRow, column: Int): Array[Byte] = {
+    source.getAs[UTF8String](column).getBytes
+  }
+}
+
+private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+  def getBytes(source: InternalRow, column: Int): Array[Byte] = {
+    source.getAs[Array[Byte]](column)
+  }
+}
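The `(numBytes & 0x07) > 0` branch above zeroes only the last, partially filled word before the payload is copied over it. A minimal Java sketch of that arithmetic; `roundToWord` is an illustrative stand-in for `ByteArrayMethods.roundNumberOfBytesToNearestWord`, not the real method.

```java
public class PaddingDemo {
    // Round up to the next multiple of 8 bytes (one word), as the size
    // accounting in getSize does.
    static int roundToWord(int numBytes) {
        return (numBytes + 7) / 8 * 8;
    }

    // Start offset of the last (possibly partial) word: clear the low 3 bits.
    static int lastWordStart(int numBytes) {
        return (numBytes >> 3) << 3;
    }

    public static void main(String[] args) {
        int numBytes = "Hello".getBytes().length;    // 5
        System.out.println(roundToWord(numBytes));   // 8: one word reserved
        System.out.println(lastWordStart(numBytes)); // 0: zero the word at offset 0
        // The writer zeroes that word with putLong, then copyMemory overlays
        // the 5 payload bytes, leaving bytes 5..7 as deterministic zero padding.
    }
}
```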
UnsafeRowConverterSuite.scala
@@ -23,8 +23,8 @@ import java.util.Arrays

 import org.scalatest.Matchers

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types._
 import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -52,19 +52,19 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     unsafeRow.getInt(2) should be (2)
   }

-  test("basic conversion with primitive and string types") {
-    val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
+  test("basic conversion with primitive, string and binary types") {
+    val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType)
     val converter = new UnsafeRowConverter(fieldTypes)

     val row = new SpecificMutableRow(fieldTypes)
     row.setLong(0, 0)
     row.setString(1, "Hello")
-    row.setString(2, "World")
+    row.update(2, "World".getBytes)

     val sizeRequired: Int = converter.getSizeRequirement(row)
     sizeRequired should be (8 + (8 * 3) +
-      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
-      ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8))
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
     numBytesWritten should be (sizeRequired)
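For the three-field row in this test, the expected size works out as follows; this is just a recomputation of the assertion above in plain Java, not part of the suite.

```java
public class SizeRequirementDemo {
    public static void main(String[] args) {
        int bitset = 8;              // one word of null-tracking bits for 3 fields
        int fixed = 8 * 3;           // one 8-byte word per field
        int hello = (5 + 7) / 8 * 8; // "Hello": 5 bytes rounded up to 8
        int world = (5 + 7) / 8 * 8; // "World": 5 bytes rounded up to 8
        System.out.println(bitset + fixed + hello + world); // 48
        // Under the old layout each variable-length field also carried an
        // 8-byte length prefix (the "+ 8" removed in this diff): 64 in total.
    }
}
```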
@@ -73,7 +73,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
     unsafeRow.getLong(0) should be (0)
     unsafeRow.getString(1) should be ("Hello")
-    unsafeRow.getString(2) should be ("World")
+    unsafeRow.getBinary(2) should be ("World".getBytes)
   }

   test("basic conversion with primitive, string, date and timestamp types") {
test("basic conversion with primitive, string, date and timestamp types") {
Expand All @@ -88,7 +88,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {

val sizeRequired: Int = converter.getSizeRequirement(row)
sizeRequired should be (8 + (8 * 4) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8))
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
numBytesWritten should be (sizeRequired)