From 447dea051b13da73e0b84e3de72fd16e6d466765 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 19 Jun 2015 16:44:30 -0700
Subject: [PATCH 1/6] support binaryType in UnsafeRow

---
 .../sql/catalyst/expressions/UnsafeRow.java   | 33 ++++++-------
 .../expressions/UnsafeRowConverter.scala      | 46 +++++++++++++++----
 .../expressions/UnsafeRowConverterSuite.scala | 16 +++----
 3 files changed, 61 insertions(+), 34 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index c4b7f8490a05..9bd8687e6c2e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -92,6 +92,7 @@ public static int calculateBitSetWidthInBytes(int numFields) {
    */
   public static final Set<DataType> readableFieldTypes;
 
+  // TODO: support DecimalType
   static {
     settableFieldTypes = Collections.unmodifiableSet(
       new HashSet<DataType>(
@@ -111,7 +112,8 @@ public static int calculateBitSetWidthInBytes(int numFields) {
     // We support get() on a superset of the types for which we support set():
     final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
       Arrays.asList(new DataType[]{
-        StringType
+        StringType,
+        BinaryType
       }));
     _readableFieldTypes.addAll(settableFieldTypes);
     readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
@@ -221,11 +223,6 @@ public void setFloat(int ordinal, float value) {
     PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
   }
 
-  @Override
-  public void setString(int ordinal, String value) {
-    throw new UnsupportedOperationException();
-  }
-
   @Override
   public int size() {
     return numFields;
@@ -249,6 +246,8 @@ public Object get(int i) {
       return null;
     } else if (dataType == StringType) {
       return getUTF8String(i);
+    } else if (dataType == BinaryType) {
+      return getBinary(i);
     } else {
       throw new UnsupportedOperationException();
     }
@@ -311,21 +310,23 @@ public double getDouble(int i) {
   }
 
   public UTF8String getUTF8String(int i) {
+    return UTF8String.fromBytes(getBinary(i));
+  }
+
+  public byte[] getBinary(int i) {
     assertIndexIsValid(i);
-    final UTF8String str = new UTF8String();
-    final long offsetToStringSize = getLong(i);
-    final int stringSizeInBytes =
-      (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize);
-    final byte[] strBytes = new byte[stringSizeInBytes];
+    final long offsetAndSize = getLong(i);
+    final int offset = (int)(offsetAndSize >> 32);
+    final int size = (int)(offsetAndSize & ((1L << 32) - 1));
+    final byte[] bytes = new byte[size];
     PlatformDependent.copyMemory(
       baseObject,
-      baseOffset + offsetToStringSize + 8,  // The `+ 8` is to skip past the size to get the data
-      strBytes,
+      baseOffset + offset,
+      bytes,
       PlatformDependent.BYTE_ARRAY_OFFSET,
-      stringSizeInBytes
+      size
     );
-    str.set(strBytes);
-    return str;
+    return bytes;
   }
 
   @Override
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 72f740ecaead..6f154915647f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.util.DateUtils
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -122,6 +120,7 @@ private object UnsafeColumnWriter {
       case FloatType => FloatUnsafeColumnWriter
       case DoubleType => DoubleUnsafeColumnWriter
       case StringType => StringUnsafeColumnWriter
+      case BinaryType => BinaryUnsafeColumnWriter
       case DateType => IntUnsafeColumnWriter
       case TimestampType => LongUnsafeColumnWriter
       case t =>
@@ -141,6 +140,7 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
 private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
 private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
 private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
+private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter
 
 private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
   // Primitives don't write to the variable-length region:
@@ -238,7 +238,7 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr
 private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
   def getSize(source: InternalRow, column: Int): Int = {
     val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length
-    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
 
   override def write(
@@ -246,19 +246,45 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
       target: UnsafeRow,
       column: Int,
       appendCursor: Int): Int = {
-    val value = source.get(column).asInstanceOf[UTF8String]
+    val value = source.get(column).asInstanceOf[UTF8String].getBytes
     val baseObject = target.getBaseObject
     val baseOffset = target.getBaseOffset
-    val numBytes = value.getBytes.length
-    PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
+    val numBytes = value.length
     PlatformDependent.copyMemory(
-      value.getBytes,
+      value,
       PlatformDependent.BYTE_ARRAY_OFFSET,
       baseObject,
-      baseOffset + appendCursor + 8,
+      baseOffset + appendCursor,
       numBytes
     )
-    target.setLong(column, appendCursor)
-    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+    target.setLong(column, (appendCursor.toLong << 32) | numBytes.toLong)
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
 }
+
+private class BinaryUnsafeColumnWriter private() extends UnsafeColumnWriter {
+  def getSize(source: InternalRow, column: Int): Int = {
+    val numBytes = source.getAs[Array[Byte]](column).length
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+  }
+
+  override def write(
+      source: InternalRow,
+      target: UnsafeRow,
+      column: Int,
+      appendCursor: Int): Int = {
+    val value = source.getAs[Array[Byte]](column)
+    val baseObject = target.getBaseObject
+    val baseOffset = target.getBaseOffset
+    val numBytes = value.length
+    PlatformDependent.copyMemory(
+      value,
+      PlatformDependent.BYTE_ARRAY_OFFSET,
+      baseObject,
+      baseOffset + appendCursor,
+      numBytes
+    )
+    target.setLong(column, (appendCursor.toLong << 32) | numBytes.toLong)
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+  }
+}
\ No newline at end of file
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 721ef8a22608..d8f3351d6dff 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -23,8 +23,8 @@ import java.util.Arrays
 import org.scalatest.Matchers
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types._
 import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.array.ByteArrayMethods
 
@@ -52,19 +52,19 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     unsafeRow.getInt(2) should be (2)
   }
 
-  test("basic conversion with primitive and string types") {
-    val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
+  test("basic conversion with primitive, string and binary types") {
+    val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType)
     val converter = new UnsafeRowConverter(fieldTypes)
 
     val row = new SpecificMutableRow(fieldTypes)
     row.setLong(0, 0)
     row.setString(1, "Hello")
-    row.setString(2, "World")
+    row.update(2, "World".getBytes)
 
     val sizeRequired: Int = converter.getSizeRequirement(row)
     sizeRequired should be (8 + (8 * 3) +
-      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
-      ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8))
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
     numBytesWritten should be (sizeRequired)
@@ -73,7 +73,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
     unsafeRow.getLong(0) should be (0)
     unsafeRow.getString(1) should be ("Hello")
-    unsafeRow.getString(2) should be ("World")
+    unsafeRow.getBinary(2) should be ("World".getBytes)
   }
 
   test("basic conversion with primitive, string, date and timestamp types") {
@@ -88,7 +88,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
 
     val sizeRequired: Int = converter.getSizeRequirement(row)
     sizeRequired should be (8 + (8 * 4) +
-      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8))
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
     numBytesWritten should be (sizeRequired)
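Note (not part of the patch): the old layout stored a variable-length field as an 8-byte length word followed by the data, costing 8 extra bytes per field and an extra memory read on access. Patch 1 instead packs the field's relative offset into the high 32 bits and its byte length into the low 32 bits of the field's fixed-width word, which is what the writers encode and UnsafeRow.getBinary decodes. A minimal standalone Scala sketch of that encoding; pack and unpack are hypothetical helper names introduced here for illustration:

    object OffsetAndSizeDemo {
      // Mirrors the writers: `(appendCursor.toLong << 32) | numBytes.toLong`.
      def pack(offset: Int, size: Int): Long =
        (offset.toLong << 32) | (size.toLong & 0xFFFFFFFFL)

      // Mirrors UnsafeRow.getBinary: high 32 bits hold the offset,
      // low 32 bits hold the size.
      def unpack(word: Long): (Int, Int) =
        ((word >> 32).toInt, (word & ((1L << 32) - 1)).toInt)

      def main(args: Array[String]): Unit = {
        val word = pack(40, 5)   // e.g. a 5-byte value written at relative offset 40
        assert(unpack(word) == ((40, 5)))
      }
    }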
From 6abfe93adb9fb6a48444b1f2ae7a8234af15b6a9 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 19 Jun 2015 18:17:25 -0700
Subject: [PATCH 2/6] fix style

---
 .../spark/sql/catalyst/expressions/UnsafeRowConverter.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 6f154915647f..fc4fe3bbcd65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -287,4 +287,4 @@ private class BinaryUnsafeColumnWriter private() extends UnsafeColumnWriter {
     target.setLong(column, (appendCursor.toLong << 32) | numBytes.toLong)
     ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
-}
\ No newline at end of file
+}
From 22e4c0afdb86a84316d9703109c098f3ec8337fa Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 19 Jun 2015 23:39:14 -0700
Subject: [PATCH 3/6] zero-out padding bytes

---
 .../expressions/UnsafeRowConverter.scala      | 55 ++++++++-----------
 1 file changed, 23 insertions(+), 32 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index fc4fe3bbcd65..fe78fff1a778 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -235,9 +235,12 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr
   }
 }
 
-private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
+private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
+
+  def getBytes(source: InternalRow, column: Int): Array[Byte]
+
   def getSize(source: InternalRow, column: Int): Int = {
-    val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length
+    val numBytes = getBytes(source, column).length
     ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
 
@@ -246,45 +249,33 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
       target: UnsafeRow,
       column: Int,
       appendCursor: Int): Int = {
-    val value = source.get(column).asInstanceOf[UTF8String].getBytes
-    val baseObject = target.getBaseObject
-    val baseOffset = target.getBaseOffset
-    val numBytes = value.length
+    val offset = target.getBaseOffset + appendCursor
+    val bytes = getBytes(source, column)
+    val numBytes = bytes.length
+    if ((numBytes & 0x07) > 0) {
+      // zero-out the padding bytes
+      PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + (numBytes >> 3 << 3), 0L)
+    }
     PlatformDependent.copyMemory(
-      value,
+      bytes,
       PlatformDependent.BYTE_ARRAY_OFFSET,
-      baseObject,
-      baseOffset + appendCursor,
+      target.getBaseObject,
+      offset,
       numBytes
     )
-    target.setLong(column, (appendCursor.toLong << 32) | numBytes.toLong)
+    target.setLong(column, (appendCursor.toLong << 32L) | numBytes.toLong)
     ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
 }
 
-private class BinaryUnsafeColumnWriter private() extends UnsafeColumnWriter {
-  def getSize(source: InternalRow, column: Int): Int = {
-    val numBytes = source.getAs[Array[Byte]](column).length
-    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+  def getBytes(source: InternalRow, column: Int): Array[Byte] = {
+    source.getAs[UTF8String](column).getBytes
   }
+}
 
-  override def write(
-      source: InternalRow,
-      target: UnsafeRow,
-      column: Int,
-      appendCursor: Int): Int = {
-    val value = source.getAs[Array[Byte]](column)
-    val baseObject = target.getBaseObject
-    val baseOffset = target.getBaseOffset
-    val numBytes = value.length
-    PlatformDependent.copyMemory(
-      value,
-      PlatformDependent.BYTE_ARRAY_OFFSET,
-      baseObject,
-      baseOffset + appendCursor,
-      numBytes
-    )
-    target.setLong(column, (appendCursor.toLong << 32) | numBytes.toLong)
-    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+  def getBytes(source: InternalRow, column: Int): Array[Byte] = {
+    source.getAs[Array[Byte]](column)
   }
 }
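Note (not part of the patch): the copy in patch 3 only fills numBytes of the rounded-up region, so without the putLong the trailing pad bytes would keep whatever the reused buffer held before, and two rows with equal contents could compare or hash unequal byte-for-byte. A standalone Scala sketch of the same trick against an on-heap byte array, assuming the patch's 8-byte word alignment; writePadded is a hypothetical name, and Arrays.fill plus System.arraycopy stand in for PlatformDependent:

    object PaddingDemo {
      def writePadded(target: Array[Byte], appendCursor: Int, bytes: Array[Byte]): Int = {
        val numBytes = bytes.length
        if ((numBytes & 0x07) > 0) {
          // Zero the last, partially-filled word; (numBytes >> 3) << 3 rounds
          // down to that word's start, exactly as in the patch.
          val lastWordStart = appendCursor + ((numBytes >> 3) << 3)
          java.util.Arrays.fill(target, lastWordStart, lastWordStart + 8, 0.toByte)
        }
        System.arraycopy(bytes, 0, target, appendCursor, numBytes)
        (numBytes + 7) & ~7  // rounded-up length, like roundNumberOfBytesToNearestWord
      }

      def main(args: Array[String]): Unit = {
        val buf = Array.fill[Byte](16)(0x55.toByte)       // simulate a dirty buffer
        val written = writePadded(buf, 8, "Hi".getBytes)  // 2 payload + 6 pad bytes
        assert(written == 8 && buf.slice(10, 16).forall(_ == 0))
      }
    }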
From 180b49d671606d468c825057b3c6d6af4674955b Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 19 Jun 2015 23:58:54 -0700
Subject: [PATCH 4/6] fix zero-out

---
 .../expressions/UnsafeFixedWidthAggregationMap.java |  8 --------
 .../catalyst/expressions/UnsafeRowConverter.scala   | 13 +++++++++++++
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index f7849ebebc57..83f2a312972f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
-import java.util.Arrays;
 import java.util.Iterator;
 
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -142,14 +141,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
     final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
     // Make sure that the buffer is large enough to hold the key. If it's not, grow it:
     if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
-      // This new array will be initially zero, so there's no need to zero it out here
       groupingKeyConversionScratchSpace = new byte[groupingKeySize];
-    } else {
-      // Zero out the buffer that's used to hold the current row. This is necessary in order
-      // to ensure that rows hash properly, since garbage data from the previous row could
-      // otherwise end up as padding in this row. As a performance optimization, we only zero out
-      // the portion of the buffer that we'll actually write to.
-      Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, (byte) 0);
     }
     final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
       groupingKey,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index fe78fff1a778..fab1afd87f35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -70,6 +70,19 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
    */
   def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long): Int = {
     unsafeRow.pointTo(baseObject, baseOffset, writers.length, null)
+
+    if (writers.length > 0) {
+      // zero-out the bitset
+      var n = writers.length / 64
+      while (n >= 0) {
+        PlatformDependent.UNSAFE.putLong(
+          unsafeRow.getBaseObject,
+          unsafeRow.getBaseOffset + n * 8,
+          0L)
+        n -= 1
+      }
+    }
+
     var fieldNumber = 0
     var appendCursor: Int = fixedLengthSize
     while (fieldNumber < writers.length) {
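Note (not part of the patch): with padding already zeroed by patch 3, the null bitset was the last source of nondeterministic bytes in a reused buffer, so zeroing just those leading words inside writeRow lets the aggregation map drop its Arrays.fill over the whole scratch array. A Scala sketch of the loop against an on-heap long array; zeroBitset is a hypothetical name. One hedged observation: when the field count is an exact multiple of 64 the loop clears one word beyond the bitset; that word is the first fixed-width field slot, which appears harmless since it is rewritten for non-null fields and left as a deterministic zero otherwise:

    object BitsetZeroDemo {
      // One null bit per field packs into 64-bit words at the start of the row;
      // the loop walks words n = numFields / 64 down to 0, as in writeRow.
      def zeroBitset(row: Array[Long], numFields: Int): Unit = {
        var n = numFields / 64
        while (n >= 0) {
          row(n) = 0L
          n -= 1
        }
      }

      def main(args: Array[String]): Unit = {
        val row = Array.fill(4)(-1L)   // four words of garbage
        zeroBitset(row, 65)            // 65 fields -> a two-word bitset
        assert(row.sameElements(Array(0L, 0L, -1L, -1L)))
      }
    }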
From 519f6988418a71c42cfe6465aa8228724aa20afe Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 22 Jun 2015 12:58:58 -0700
Subject: [PATCH 5/6] address comment

---
 .../spark/sql/catalyst/expressions/UnsafeRowConverter.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index fab1afd87f35..89adaf053b1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -267,7 +267,7 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
     val numBytes = bytes.length
     if ((numBytes & 0x07) > 0) {
       // zero-out the padding bytes
-      PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + (numBytes >> 3 << 3), 0L)
+      PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + ((numBytes >> 3) << 3), 0L)
     }
     PlatformDependent.copyMemory(
       bytes,
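Note (not part of the patch): patch 5 is purely cosmetic. In both Scala and Java the shift operators share one precedence level and associate to the left, so numBytes >> 3 << 3 already evaluated as (numBytes >> 3) << 3, i.e. numBytes rounded down to a multiple of 8. A self-contained check:

    object PrecedenceCheck {
      def main(args: Array[String]): Unit = {
        val numBytes = 13
        // The added parentheses do not change the result.
        assert((numBytes >> 3 << 3) == ((numBytes >> 3) << 3))  // both are 8
      }
    }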
From d68706fee5261d253278ff1d3af83a1a2846c5a7 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 22 Jun 2015 13:02:04 -0700
Subject: [PATCH 6/6] update comment

---
 .../org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 9bd8687e6c2e..bb2f2079b40f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -47,7 +47,8 @@
  * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length
  * primitive types, such as long, double, or int, we store the value directly in the word. For
  * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the
- * base address of the row) that points to the beginning of the variable-length field.
+ * base address of the row) that points to the beginning of the variable-length field, and length
+ * (they are combined into a long).
  *
  * Instances of `UnsafeRow` act as pointers to row data stored in this format.
  */
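Note (not part of the patch): putting the series together, the row from the updated test — fields (LongType, StringType, BinaryType) with values (0L, "Hello", "World") — occupies 48 bytes: one bitset word, three fixed-width words, then "Hello" at relative offset 32 (stored word (32L << 32) | 5) and "World" at offset 40 (stored word (40L << 32) | 5), each padded to 8 bytes. A sketch of the arithmetic, with round re-implementing ByteArrayMethods.roundNumberOfBytesToNearestWord locally:

    object RowSizeDemo {
      def round(n: Int): Int = (n + 7) & ~7   // round up to a multiple of 8

      def main(args: Array[String]): Unit = {
        val bitset = 8                              // one word holds all 3 null bits
        val fixed = 8 * 3                           // one word per field
        val hello = round("Hello".getBytes.length)  // 5 -> 8
        val world = round("World".getBytes.length)  // 5 -> 8
        assert(bitset + fixed + hello + world == 48)
      }
    }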