From 0379f7cbff5cf58abe5fb5c1af6a64cd17e578fd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 17 Mar 2018 21:10:20 +0100 Subject: [PATCH 01/11] initial commit --- .../sql/kafka010/KafkaContinuousReader.scala | 3 - .../KafkaRecordToUnsafeRowConverter.scala | 9 +- .../expressions/codegen/BufferHolder.java | 24 ++- .../codegen/UnsafeArrayWriter.java | 125 +++---------- .../expressions/codegen/UnsafeRowWriter.java | 176 ++++++------------ .../expressions/codegen/UnsafeWriter.java | 146 ++++++++++++++- .../InterpretedUnsafeProjection.scala | 77 ++++---- .../codegen/GenerateUnsafeProjection.scala | 100 +++++----- .../RowBasedKeyValueBatchSuite.java | 22 +-- .../aggregate/RowBasedHashMapGenerator.scala | 11 +- .../columnar/GenerateColumnAccessor.scala | 7 +- .../datasources/text/TextFileFormat.scala | 11 +- 12 files changed, 358 insertions(+), 353 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index e7e27876088f..f26c134c2f6e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -27,13 +27,10 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String /** * A [[ContinuousReader]] for data from kafka. 
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala index 1acdd5612574..05d2f9dc74b5 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -20,18 +20,17 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.clients.consumer.ConsumerRecord import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.unsafe.types.UTF8String /** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ private[kafka010] class KafkaRecordToUnsafeRowConverter { private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + private val rowWriter = new UnsafeRowWriter(sharedRow) def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { - bufferHolder.reset() + rowWriter.reset() if (record.key == null) { rowWriter.setNullAt(0) @@ -46,7 +45,7 @@ private[kafka010] class KafkaRecordToUnsafeRowConverter { 5, DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) rowWriter.write(6, record.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) + rowWriter.setTotalSize() sharedRow } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 259976118c12..3eb3e3993c3d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -31,24 +31,24 @@ * for each incoming record, we should call `reset` of BufferHolder instance before write the record * and reuse the data buffer. * - * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update + * Generally we should call `UnsafeRowWriter.setTotalSize` using `BufferHolder.totalSize` to update * the size of the result row, after writing a record to the buffer. However, we can skip this step * if the fields of row are all fixed-length, as the size of result row is also fixed. 
*/ -public class BufferHolder { +public final class BufferHolder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - public byte[] buffer; - public int cursor = Platform.BYTE_ARRAY_OFFSET; + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; private final int fixedSize; - public BufferHolder(UnsafeRow row) { + BufferHolder(UnsafeRow row) { this(row, 64); } - public BufferHolder(UnsafeRow row, int initialSize) { + BufferHolder(UnsafeRow row, int initialSize) { int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()); if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) { throw new UnsupportedOperationException( @@ -64,7 +64,7 @@ public BufferHolder(UnsafeRow row, int initialSize) { /** * Grows the buffer by at least neededSize and points the row to the buffer. */ - public void grow(int neededSize) { + void grow(int neededSize) { if (neededSize > ARRAY_MAX - totalSize()) { throw new UnsupportedOperationException( "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " + @@ -86,11 +86,17 @@ public void grow(int neededSize) { } } - public void reset() { + byte[] buffer() { return buffer; } + + int getCursor() { return cursor; } + + void addCursor(int val) { cursor += val; } + + void reset() { cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } - public int totalSize() { + int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 82cd1b24607e..1a15b6cc57fd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -21,8 +21,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes; @@ -32,11 +30,6 @@ */ public final class UnsafeArrayWriter extends UnsafeWriter { - private BufferHolder holder; - - // The offset of the global buffer where we start to write this array. - private int startingOffset; - // The number of elements in this array private int numElements; @@ -47,13 +40,16 @@ private void assertIndexIsValid(int index) { assert index < numElements : "index (" + index + ") should < " + numElements; } - public void initialize(BufferHolder holder, int numElements, int elementSize) { + public UnsafeArrayWriter(UnsafeWriter writer) { + super(writer.getBufferHolder()); + } + + public void initialize(int numElements, int elementSize) { // We need 8 bytes to store numElements in header this.numElements = numElements; this.headerInBytes = calculateHeaderPortionInBytes(numElements); - this.holder = holder; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); // Grows the global buffer ahead for header and fixed size data. 
int fixedPartInBytes = @@ -61,112 +57,102 @@ public void initialize(BufferHolder holder, int numElements, int elementSize) { holder.grow(headerInBytes + fixedPartInBytes); // Write numElements and clear out null bits to header - Platform.putLong(holder.buffer, startingOffset, numElements); + Platform.putLong(buffer(), startingOffset, numElements); for (int i = 8; i < headerInBytes; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); + Platform.putLong(buffer(), startingOffset + i, 0L); } // fill 0 into reminder part of 8-bytes alignment in unsafe array for (int i = elementSize * numElements; i < fixedPartInBytes; i++) { - Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0); + Platform.putByte(buffer(), startingOffset + headerInBytes + i, (byte) 0); } - holder.cursor += (headerInBytes + fixedPartInBytes); + addCursor(headerInBytes + fixedPartInBytes); } - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); - } + protected long getOffset(int ordinal, int elementSize) { + return getElementOffset(ordinal, elementSize); } private long getElementOffset(int ordinal, int elementSize) { return startingOffset + headerInBytes + ordinal * elementSize; } + @Override public void setOffsetAndSize(int ordinal, int currentCursor, int size) { assertIndexIsValid(ordinal); - final long relativeOffset = currentCursor - startingOffset; - final long offsetAndSize = (relativeOffset << 32) | (long)size; - - write(ordinal, offsetAndSize); + _setOffsetAndSize(ordinal, currentCursor, size); } private void setNullBit(int ordinal) { assertIndexIsValid(ordinal); - BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); + BitSetMethods.set(buffer(), startingOffset + 8, ordinal); } public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); + Platform.putByte(buffer(), getElementOffset(ordinal, 1), (byte)0); } public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); + Platform.putShort(buffer(), getElementOffset(ordinal, 2), (short)0); } public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); + Platform.putInt(buffer(), getElementOffset(ordinal, 4), 0); } public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); + Platform.putLong(buffer(), getElementOffset(ordinal, 8), (long)0); } public void setNull(int ordinal) { setNull8Bytes(ordinal); } public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); - Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value); + _write(getElementOffset(ordinal, 1), value); } public void write(int ordinal, byte value) { assertIndexIsValid(ordinal); - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value); + _write(getElementOffset(ordinal, 1), value); } public void write(int ordinal, short value) { assertIndexIsValid(ordinal); - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value); + _write(getElementOffset(ordinal, 2), value); } public void write(int 
ordinal, int value) { assertIndexIsValid(ordinal); - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value); + _write(getElementOffset(ordinal, 4), value); } public void write(int ordinal, long value) { assertIndexIsValid(ordinal); - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value); + _write(getElementOffset(ordinal, 8), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } assertIndexIsValid(ordinal); - Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value); + _write(getElementOffset(ordinal, 4), value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } assertIndexIsValid(ordinal); - Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value); + _write(getElementOffset(ordinal, 8), value); } public void write(int ordinal, Decimal input, int precision, int scale) { // make sure Decimal object has the same scale as DecimalType assertIndexIsValid(ordinal); - if (input.changePrecision(precision, scale)) { + if (input != null && input.changePrecision(precision, scale)) { if (precision <= Decimal.MAX_LONG_DIGITS()) { write(ordinal, input.toUnscaledLong()); } else { @@ -180,65 +166,14 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - setOffsetAndSize(ordinal, holder.cursor, numBytes); + bytes, Platform.BYTE_ARRAY_OFFSET, buffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); // move the cursor forward with 8-bytes boundary - holder.cursor += roundedSize; + addCursor(roundedSize); } } else { setNull(ordinal); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - final int numBytes = input.length; - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, holder.cursor, 16); - - // move the cursor forward. 
- holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 2620bbcfb87a..f8a78fa19555 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -20,10 +20,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; /** * A helper class to write data into global row buffer using `UnsafeRow` format. @@ -40,29 +37,45 @@ */ public final class UnsafeRowWriter extends UnsafeWriter { - private final BufferHolder holder; - // The offset of the global buffer where we start to write this row. - private int startingOffset; + private final UnsafeRow row; + private final int nullBitsSize; private final int fixedSize; - public UnsafeRowWriter(BufferHolder holder, int numFields) { - this.holder = holder; + public UnsafeRowWriter(UnsafeRow row, int initialBufferSize) { + this(row, new BufferHolder(row, initialBufferSize), row.numFields()); + } + + public UnsafeRowWriter(UnsafeRow row) { + this(row, new BufferHolder(row), row.numFields()); + } + + public UnsafeRowWriter(UnsafeWriter writer, int numFields) { + this(null, writer.getBufferHolder(), numFields); + } + + private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) { + super(holder); + this.row = row; this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); this.fixedSize = nullBitsSize + 8 * numFields; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); + } + + public void setTotalSize() { + row.setTotalSize(totalSize()); } /** * Resets the `startingOffset` according to the current cursor of row buffer, and clear out null * bits. This should be called before we write a new nested struct to the row buffer. */ - public void reset() { - this.startingOffset = holder.cursor; + public void resetRowWriter() { + this.startingOffset = cursor(); // grow the global buffer to make sure it has enough space to write fixed-length data. 
-    holder.grow(fixedSize);
-    holder.cursor += fixedSize;
+    grow(fixedSize);
+    addCursor(fixedSize);
     zeroOutNullBytes();
   }
@@ -72,25 +85,17 @@ public void reset() {
    */
   public void zeroOutNullBytes() {
     for (int i = 0; i < nullBitsSize; i += 8) {
-      Platform.putLong(holder.buffer, startingOffset + i, 0L);
-    }
-  }
-
-  private void zeroOutPaddingBytes(int numBytes) {
-    if ((numBytes & 0x07) > 0) {
-      Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
+      Platform.putLong(buffer(), startingOffset + i, 0L);
     }
   }
 
-  public BufferHolder holder() { return holder; }
-
   public boolean isNullAt(int ordinal) {
-    return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal);
+    return BitSetMethods.isSet(buffer(), startingOffset, ordinal);
   }
 
   public void setNullAt(int ordinal) {
-    BitSetMethods.set(holder.buffer, startingOffset, ordinal);
-    Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
+    BitSetMethods.set(buffer(), startingOffset, ordinal);
+    Platform.putLong(buffer(), getFieldOffset(ordinal), 0L);
   }
 
   @Override
@@ -113,71 +118,63 @@ public void setNull8Bytes(int ordinal) {
     setNullAt(ordinal);
   }
 
-  public long getFieldOffset(int ordinal) {
-    return startingOffset + nullBitsSize + 8 * ordinal;
+  @Override
+  protected final long getOffset(int ordinal, int elementSize) {
+    return getFieldOffset(ordinal);
   }
 
-  public void setOffsetAndSize(int ordinal, int size) {
-    setOffsetAndSize(ordinal, holder.cursor, size);
+  public long getFieldOffset(int ordinal) {
+    return startingOffset + nullBitsSize + 8 * ordinal;
   }
 
+  @Override
   public void setOffsetAndSize(int ordinal, int currentCursor, int size) {
-    final long relativeOffset = currentCursor - startingOffset;
-    final long fieldOffset = getFieldOffset(ordinal);
-    final long offsetAndSize = (relativeOffset << 32) | (long) size;
-
-    Platform.putLong(holder.buffer, fieldOffset, offsetAndSize);
+    _setOffsetAndSize(ordinal, currentCursor, size);
   }
 
   public void write(int ordinal, boolean value) {
     final long offset = getFieldOffset(ordinal);
-    Platform.putLong(holder.buffer, offset, 0L);
-    Platform.putBoolean(holder.buffer, offset, value);
+    Platform.putLong(buffer(), offset, 0L);
+    _write(offset, value);
   }
 
   public void write(int ordinal, byte value) {
     final long offset = getFieldOffset(ordinal);
-    Platform.putLong(holder.buffer, offset, 0L);
-    Platform.putByte(holder.buffer, offset, value);
+    Platform.putLong(buffer(), offset, 0L);
+    _write(offset, value);
   }
 
   public void write(int ordinal, short value) {
     final long offset = getFieldOffset(ordinal);
-    Platform.putLong(holder.buffer, offset, 0L);
-    Platform.putShort(holder.buffer, offset, value);
+    Platform.putLong(buffer(), offset, 0L);
+    _write(offset, value);
   }
 
   public void write(int ordinal, int value) {
     final long offset = getFieldOffset(ordinal);
-    Platform.putLong(holder.buffer, offset, 0L);
-    Platform.putInt(holder.buffer, offset, value);
+    Platform.putLong(buffer(), offset, 0L);
+    _write(offset, value);
   }
 
   public void write(int ordinal, long value) {
-    Platform.putLong(holder.buffer, getFieldOffset(ordinal), value);
+    _write(getFieldOffset(ordinal), value);
  }
 
   public void write(int ordinal, float value) {
-    if (Float.isNaN(value)) {
-      value = Float.NaN;
-    }
     final long offset = getFieldOffset(ordinal);
-    Platform.putLong(holder.buffer, offset, 0L);
-    Platform.putFloat(holder.buffer, offset, value);
+    Platform.putLong(buffer(), offset, 0L);
+    _write(offset, value);
   }
 
   public void write(int ordinal, double value) {
-    if (Double.isNaN(value)) {
-      value = Double.NaN;
-    }
-
-    Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value);
+    _write(getFieldOffset(ordinal), value);
   }
 
   public void write(int ordinal, Decimal input, int precision, int scale) {
     if (precision <= Decimal.MAX_LONG_DIGITS()) {
       // make sure Decimal object has the same scale as DecimalType
       if (input.changePrecision(precision, scale)) {
-        Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+        write(ordinal, input.toUnscaledLong());
       } else {
         setNullAt(ordinal);
       }
@@ -185,82 +182,31 @@ public void write(int ordinal, Decimal input, int precision, int scale) {
       // grow the global buffer before writing data.
       holder.grow(16);
 
-      // zero-out the bytes
-      Platform.putLong(holder.buffer, holder.cursor, 0L);
-      Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
-
       // Make sure Decimal object has the same scale as DecimalType.
       // Note that we may pass in null Decimal object to set null for it.
       if (input == null || !input.changePrecision(precision, scale)) {
-        BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+        // zero-out the bytes
+        Platform.putLong(buffer(), cursor(), 0L);
+        Platform.putLong(buffer(), cursor() + 8, 0L);
+
+        BitSetMethods.set(buffer(), startingOffset, ordinal);
         // keep the offset for future update
         setOffsetAndSize(ordinal, 0);
       } else {
         final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
-        assert bytes.length <= 16;
+        final int numBytes = bytes.length;
+        assert numBytes <= 16;
+
+        zeroOutPaddingBytes(numBytes);
 
         // Write the bytes to the variable length portion.
         Platform.copyMemory(
-          bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+          bytes, Platform.BYTE_ARRAY_OFFSET, buffer(), cursor(), numBytes);
         setOffsetAndSize(ordinal, bytes.length);
       }
 
       // move the cursor forward.
-      holder.cursor += 16;
+      addCursor(16);
     }
   }
-
-  public void write(int ordinal, UTF8String input) {
-    final int numBytes = input.numBytes();
-    final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
-
-    // grow the global buffer before writing data.
-    holder.grow(roundedSize);
-
-    zeroOutPaddingBytes(numBytes);
-
-    // Write the bytes to the variable length portion.
-    input.writeToMemory(holder.buffer, holder.cursor);
-
-    setOffsetAndSize(ordinal, numBytes);
-
-    // move the cursor forward.
-    holder.cursor += roundedSize;
-  }
-
-  public void write(int ordinal, byte[] input) {
-    write(ordinal, input, 0, input.length);
-  }
-
-  public void write(int ordinal, byte[] input, int offset, int numBytes) {
-    final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
-
-    // grow the global buffer before writing data.
-    holder.grow(roundedSize);
-
-    zeroOutPaddingBytes(numBytes);
-
-    // Write the bytes to the variable length portion.
-    Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset,
-      holder.buffer, holder.cursor, numBytes);
-
-    setOffsetAndSize(ordinal, numBytes);
-
-    // move the cursor forward.
-    holder.cursor += roundedSize;
-  }
-
-  public void write(int ordinal, CalendarInterval input) {
-    // grow the global buffer before writing data.
-    holder.grow(16);
-
-    // Write the months and microseconds fields of Interval to the variable length portion.
-    Platform.putLong(holder.buffer, holder.cursor, input.months);
-    Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
-
-    setOffsetAndSize(ordinal, 16);
-
-    // move the cursor forward.
-    holder.cursor += 16;
-  }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index c94b5c7a367e..24473fa5f87f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.catalyst.expressions.codegen;
 
 import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.types.CalendarInterval;
 import org.apache.spark.unsafe.types.UTF8String;
 
@@ -24,10 +26,62 @@
  * Base class for writing Unsafe* structures.
  */
 public abstract class UnsafeWriter {
+  // Keep internal buffer holder
+  protected final BufferHolder holder;
+
+  // The offset of the global buffer where we start to write this structure.
+  protected int startingOffset;
+
+  protected UnsafeWriter(BufferHolder holder) {
+    this.holder = holder;
+  }
+
+  /**
+   * Accessor methods delegated to the BufferHolder class
+   */
+  public final BufferHolder getBufferHolder() {
+    return holder;
+  }
+
+  public final byte[] buffer() { return holder.buffer(); }
+
+  public final void reset() { holder.reset(); }
+
+  public final int totalSize() { return holder.totalSize(); }
+
+  public final void grow(int neededSize) { holder.grow(neededSize); }
+
+  public final int cursor() { return holder.getCursor(); }
+
+  public final void addCursor(int val) { holder.addCursor(val); }
+
+  public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size);
+
+  protected void setOffsetAndSize(int ordinal, int size) {
+    setOffsetAndSize(ordinal, cursor(), size);
+  }
+
+  protected void _setOffsetAndSize(int ordinal, int currentCursor, int size) {
+    final long relativeOffset = currentCursor - startingOffset;
+    final long offsetAndSize = (relativeOffset << 32) | (long)size;
+
+    write(ordinal, offsetAndSize);
+  }
+
+  protected final void zeroOutPaddingBytes(int numBytes) {
+    if ((numBytes & 0x07) > 0) {
+      Platform.putLong(buffer(), cursor() + ((numBytes >> 3) << 3), 0L);
+    }
+  }
+
+  protected abstract long getOffset(int ordinal, int elementSize);
+
   public abstract void setNull1Bytes(int ordinal);
   public abstract void setNull2Bytes(int ordinal);
   public abstract void setNull4Bytes(int ordinal);
   public abstract void setNull8Bytes(int ordinal);
+
   public abstract void write(int ordinal, boolean value);
   public abstract void write(int ordinal, byte value);
   public abstract void write(int ordinal, short value);
@@ -36,8 +90,92 @@ public abstract class UnsafeWriter {
   public abstract void write(int ordinal, float value);
   public abstract void write(int ordinal, double value);
   public abstract void write(int ordinal, Decimal input, int precision, int scale);
-  public abstract void write(int ordinal, UTF8String input);
-  public abstract void write(int ordinal, byte[] input);
-  public abstract void write(int ordinal, CalendarInterval input);
-  public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size);
+
+  public final void write(int ordinal, UTF8String input) {
+    final int numBytes = input.numBytes();
+    final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+    // grow the global buffer before writing data.
+    grow(roundedSize);
+
+    zeroOutPaddingBytes(numBytes);
+
+    // Write the bytes to the variable length portion.
+    input.writeToMemory(buffer(), cursor());
+
+    setOffsetAndSize(ordinal, numBytes);
+
+    // move the cursor forward.
+    addCursor(roundedSize);
+  }
+
+  public final void write(int ordinal, byte[] input) {
+    write(ordinal, input, 0, input.length);
+  }
+
+  public final void write(int ordinal, byte[] input, int offset, int numBytes) {
+    final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+    // grow the global buffer before writing data.
+    grow(roundedSize);
+
+    zeroOutPaddingBytes(numBytes);
+
+    // Write the bytes to the variable length portion.
+    Platform.copyMemory(
+      input, Platform.BYTE_ARRAY_OFFSET + offset, buffer(), cursor(), numBytes);
+
+    setOffsetAndSize(ordinal, numBytes);
+
+    // move the cursor forward.
+    addCursor(roundedSize);
+  }
+
+  public final void write(int ordinal, CalendarInterval input) {
+    // grow the global buffer before writing data.
+    grow(16);
+
+    // Write the months and microseconds fields of Interval to the variable length portion.
+    Platform.putLong(buffer(), cursor(), input.months);
+    Platform.putLong(buffer(), cursor() + 8, input.microseconds);
+
+    setOffsetAndSize(ordinal, 16);
+
+    // move the cursor forward.
+    addCursor(16);
+  }
+
+  protected final void _write(long offset, boolean value) {
+    Platform.putBoolean(buffer(), offset, value);
+  }
+
+  protected final void _write(long offset, byte value) {
+    Platform.putByte(buffer(), offset, value);
+  }
+
+  protected final void _write(long offset, short value) {
+    Platform.putShort(buffer(), offset, value);
+  }
+
+  protected final void _write(long offset, int value) {
+    Platform.putInt(buffer(), offset, value);
+  }
+
+  protected final void _write(long offset, long value) {
+    Platform.putLong(buffer(), offset, value);
+  }
+
+  protected final void _write(long offset, float value) {
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    }
+    Platform.putFloat(buffer(), offset, value);
+  }
+
+  protected final void _write(long offset, double value) {
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    }
+    Platform.putDouble(buffer(), offset, value);
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index 0da5ece7e47f..9ba67757de8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
+import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.types.{UserDefinedType, _}
 import org.apache.spark.unsafe.Platform
@@ -45,14 +45,12 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe
 
   /** The row returned by the projection. */
   private[this] val result = new UnsafeRow(numFields)
 
-  /** The buffer which holds the resulting row's backing data.
*/ - private[this] val holder = new BufferHolder(result, numFields * 32) + /* The row writer for UnsafeRow result */ + private[this] val rowWriter = new UnsafeRowWriter(result, numFields * 32) /** The writer that writes the intermediate result to the result row. */ private[this] val writer: InternalRow => Unit = { - val rowWriter = new UnsafeRowWriter(holder, numFields) val baseWriter = generateStructWriter( - holder, rowWriter, expressions.map(e => StructField("", e.dataType, e.nullable))) if (!expressions.exists(_.nullable)) { @@ -83,9 +81,9 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe } // Write the intermediate row to an unsafe row. - holder.reset() + rowWriter.reset() writer(intermediate) - result.setTotalSize(holder.totalSize()) + rowWriter.setTotalSize() result } } @@ -111,14 +109,13 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * given buffer using the given [[UnsafeRowWriter]]. */ private def generateStructWriter( - bufferHolder: BufferHolder, rowWriter: UnsafeRowWriter, fields: Array[StructField]): InternalRow => Unit = { val numFields = fields.length // Create field writers. val fieldWriters = fields.map { field => - generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable) + generateFieldWriter(rowWriter, field.dataType, field.nullable) } // Create basic writer. row => { @@ -136,7 +133,6 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * or array) to the given buffer using the given [[UnsafeWriter]]. */ private def generateFieldWriter( - bufferHolder: BufferHolder, writer: UnsafeWriter, dt: DataType, nullable: Boolean): (SpecializedGetters, Int) => Unit = { @@ -178,81 +174,79 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { case StructType(fields) => val numFields = fields.length - val rowWriter = new UnsafeRowWriter(bufferHolder, numFields) - val structWriter = generateStructWriter(bufferHolder, rowWriter, fields) + val rowWriter = new UnsafeRowWriter(writer, numFields) + val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val tmpCursor = bufferHolder.cursor + val tmpCursor = rowWriter.cursor v.getStruct(i, fields.length) match { case row: UnsafeRow => writeUnsafeData( - bufferHolder, + rowWriter, row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) case row => // Nested struct. We don't know where this will start because a row can be // variable length, so we need to update the offsets and zero out the bit mask. 
- rowWriter.reset() + rowWriter.resetRowWriter() structWriter.apply(row) } - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + writer.setOffsetAndSize(i, tmpCursor, rowWriter.cursor - tmpCursor) } case ArrayType(elementType, containsNull) => - val arrayWriter = new UnsafeArrayWriter + val arrayWriter = new UnsafeArrayWriter(writer) val elementSize = getElementSize(elementType) val elementWriter = generateFieldWriter( - bufferHolder, arrayWriter, elementType, containsNull) (v, i) => { - val tmpCursor = bufferHolder.cursor - writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize) - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + val tmpCursor = arrayWriter.cursor + writeArray(arrayWriter, elementWriter, v.getArray(i), elementSize) + writer.setOffsetAndSize(i, tmpCursor, arrayWriter.cursor - tmpCursor) } case MapType(keyType, valueType, valueContainsNull) => - val keyArrayWriter = new UnsafeArrayWriter + val keyArrayWriter = new UnsafeArrayWriter(writer) val keySize = getElementSize(keyType) val keyWriter = generateFieldWriter( - bufferHolder, keyArrayWriter, keyType, nullable = false) - val valueArrayWriter = new UnsafeArrayWriter + val valueArrayWriter = new UnsafeArrayWriter(writer) val valueSize = getElementSize(valueType) val valueWriter = generateFieldWriter( - bufferHolder, valueArrayWriter, valueType, valueContainsNull) (v, i) => { - val tmpCursor = bufferHolder.cursor + val tmpCursor = valueArrayWriter.cursor v.getMap(i) match { case map: UnsafeMapData => writeUnsafeData( - bufferHolder, + valueArrayWriter, map.getBaseObject, map.getBaseOffset, map.getSizeInBytes) case map => // preserve 8 bytes to write the key array numBytes later. - bufferHolder.grow(8) - bufferHolder.cursor += 8 + valueArrayWriter.grow(8) + valueArrayWriter.addCursor(8) // Write the keys and write the numBytes of key array into the first 8 bytes. - writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize) - Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8) + writeArray(keyArrayWriter, keyWriter, map.keyArray(), keySize) + Platform.putLong( + valueArrayWriter.buffer, tmpCursor, valueArrayWriter.cursor - tmpCursor - 8) // Write the values. - writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize) + writeArray(valueArrayWriter, valueWriter, map.valueArray(), valueSize) } - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + writer.setOffsetAndSize(i, tmpCursor, valueArrayWriter.cursor - tmpCursor) } case udt: UserDefinedType[_] => - generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable) + generateFieldWriter(writer, udt.sqlType, nullable) case NullType => (_, _) => {} @@ -324,20 +318,19 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * copy. 
*/ private def writeArray( - bufferHolder: BufferHolder, arrayWriter: UnsafeArrayWriter, elementWriter: (SpecializedGetters, Int) => Unit, array: ArrayData, elementSize: Int): Unit = array match { case unsafe: UnsafeArrayData => writeUnsafeData( - bufferHolder, + arrayWriter, unsafe.getBaseObject, unsafe.getBaseOffset, unsafe.getSizeInBytes) case _ => val numElements = array.numElements() - arrayWriter.initialize(bufferHolder, numElements, elementSize) + arrayWriter.initialize(numElements, elementSize) var i = 0 while (i < numElements) { elementWriter.apply(array, i) @@ -350,17 +343,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. */ private def writeUnsafeData( - bufferHolder: BufferHolder, + writer: UnsafeWriter, baseObject: AnyRef, baseOffset: Long, sizeInBytes: Int) : Unit = { - bufferHolder.grow(sizeInBytes) + writer.grow(sizeInBytes) Platform.copyMemory( baseObject, baseOffset, - bufferHolder.buffer, - bufferHolder.cursor, + writer.buffer, + writer.cursor, sizeInBytes) - bufferHolder.cursor += sizeInBytes + writer.addCursor(sizeInBytes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 6682ba55b18b..5cd0b6f90b2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -48,19 +48,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, fieldTypes: Seq[DataType], - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") + s""" final InternalRow $tmpInput = $input; if ($tmpInput instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)} } else { - ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)} + ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} } """ } @@ -70,12 +74,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro row: String, inputs: Seq[ExprCode], inputTypes: Seq[DataType], - bufferHolder: String, + rowWriter: String, isTopLevel: Boolean = false): String = { - val rowWriterClass = classOf[UnsafeRowWriter].getName - val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", - v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});") - val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, // which means its fixed-size region always in the same position, so we don't need to call @@ -88,7 +88,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.zeroOutNullBytes();" } } else { - s"$rowWriter.reset();" + s"$rowWriter.resetRowWriter();" } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { @@ -111,27 +111,27 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $tmpCursor = $rowWriter.cursor(); + ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} + $rowWriter.setOffsetAndSize($index, $tmpCursor, $rowWriter.cursor() - $tmpCursor); """ case a @ ArrayType(et, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $tmpCursor = $rowWriter.cursor(); + ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} + $rowWriter.setOffsetAndSize($index, $tmpCursor, $rowWriter.cursor() - $tmpCursor); """ case m @ MapType(kt, vt, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. 
- final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $tmpCursor = $rowWriter.cursor(); + ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} + $rowWriter.setOffsetAndSize($index, $tmpCursor, $rowWriter.cursor() - $tmpCursor); """ case t: DecimalType => @@ -181,12 +181,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, elementType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", - v => s"$v = new $arrayWriterClass();") + v => s"$v = new $arrayWriterClass($rowWriter);") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") @@ -208,23 +208,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val writeElement = et match { case t: StructType => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $tmpCursor = $rowWriter.cursor(); + ${writeStructToBuffer(ctx, element, t.map(_.dataType), rowWriter)} + $arrayWriter.setOffsetAndSize($index, $tmpCursor, $rowWriter.cursor() - $tmpCursor); """ case a @ ArrayType(et, _) => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, element, et, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $tmpCursor = $rowWriter.cursor(); + ${writeArrayToBuffer(ctx, element, et, rowWriter)} + $arrayWriter.setOffsetAndSize($index, $tmpCursor, $rowWriter.cursor() - $tmpCursor); """ case m @ MapType(kt, vt, _) => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $tmpCursor = $rowWriter.cursor(); + ${writeMapToBuffer(ctx, element, kt, vt, rowWriter)} + $arrayWriter.setOffsetAndSize($index, $tmpCursor, $rowWriter.cursor() - $tmpCursor); """ case t: DecimalType => @@ -240,10 +240,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final ArrayData $tmpInput = $input; if ($tmpInput instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)} } else { final int $numElements = $tmpInput.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); + $arrayWriter.initialize($numElements, $elementOrOffsetSize); for (int $index = 0; $index < $numElements; $index++) { if ($tmpInput.isNullAt($index)) { @@ -262,7 +262,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, keyType: DataType, valueType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") @@ -271,20 +271,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final MapData $tmpInput = $input; if ($tmpInput instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)} } else { // preserve 8 bytes to write the key array numBytes later. - $bufferHolder.grow(8); - $bufferHolder.cursor += 8; + $rowWriter.grow(8); + $rowWriter.addCursor(8); // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $bufferHolder.cursor; + final int $tmpCursor = $rowWriter.cursor(); - ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} // Write the numBytes of key array into the first 8 bytes. - Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); + Platform.putLong($rowWriter.buffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); - ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} } """ } @@ -293,14 +293,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro * If the input is already in unsafe format, we don't need to go through all elements/fields, * we can directly write it. */ - private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = { + private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = { val sizeInBytes = ctx.freshName("sizeInBytes") s""" final int $sizeInBytes = $input.getSizeInBytes(); // grow the global buffer before writing data. - $bufferHolder.grow($sizeInBytes); - $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor); - $bufferHolder.cursor += $sizeInBytes; + $rowWriter.grow($sizeInBytes); + $input.writeToMemory($rowWriter.buffer(), $rowWriter.cursor()); + $rowWriter.addCursor($sizeInBytes); """ } @@ -320,26 +320,26 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val result = ctx.addMutableState("UnsafeRow", "result", v => s"$v = new UnsafeRow(${expressions.length});") - val holderClass = classOf[BufferHolder].getName - val holder = ctx.addMutableState(holderClass, "holder", - v => s"$v = new $holderClass($result, ${numVarLenFields * 32});") + val rowWriterClass = classOf[UnsafeRowWriter].getName + val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass($result, ${numVarLenFields * 32});") val resetBufferHolder = if (numVarLenFields == 0) { "" } else { - s"$holder.reset();" + s"$rowWriter.reset();" } val updateRowSize = if (numVarLenFields == 0) { "" } else { - s"$result.setTotalSize($holder.totalSize());" + s"$rowWriter.setTotalSize();" } // Evaluate all the subexpression. 
val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val writeExpressions = - writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + val writeExpressions = writeExpressionsToBuffer( + ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = s""" diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index fb3dbe8ed199..e3f25e6ef7d0 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -27,7 +27,6 @@ import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.unsafe.types.UTF8String; @@ -56,34 +55,31 @@ private String getRandomString(int length) { private UnsafeRow makeKeyRow(long k1, String k2) { UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 32); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(row); + writer.reset(); writer.write(0, k1); writer.write(1, UTF8String.fromString(k2)); - row.setTotalSize(holder.totalSize()); + writer.setTotalSize(); return row; } private UnsafeRow makeKeyRow(long k1, long k2) { UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(row); + writer.reset(); writer.write(0, k1); writer.write(1, k2); - row.setTotalSize(holder.totalSize()); + writer.setTotalSize(); return row; } private UnsafeRow makeValueRow(long v1, long v2) { UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(row); + writer.reset(); writer.write(0, v1); writer.write(1, v2); - row.setTotalSize(holder.totalSize()); + writer.setTotalSize(); return row; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 8617be88f357..7bfd1972dab0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -166,17 +166,14 @@ class RowBasedHashMapGenerator( | if (numRows < capacity && !isBatchFull) { | // creating the unsafe for new entry | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length}); - | org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder - | = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, - | ${numVarLenFields * 32}); | org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter | = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( - | agg_holder, - | ${groupingKeySchema.length}); - | agg_holder.reset(); //TODO: investigate if reset or 
zeroout are actually needed + | agg_result, + | ${numVarLenFields * 32}); + | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed | agg_rowWriter.zeroOutNullBytes(); | ${createUnsafeRowForKey}; - | agg_result.setTotalSize(agg_holder.totalSize()); + | agg_rowWriter.setTotalSize(); | Object kbase = agg_result.getBaseObject(); | long koff = agg_result.getBaseOffset(); | int klen = agg_result.getSizeInBytes(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 3b5655ba0582..d3790dcb6c46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -166,8 +166,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; private UnsafeRow unsafeRow = new UnsafeRow($numFields); - private BufferHolder bufferHolder = new BufferHolder(unsafeRow); - private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter(unsafeRow); private MutableUnsafeRow mutableRow = null; private int currentRow = 0; @@ -212,10 +211,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public InternalRow next() { currentRow += 1; - bufferHolder.reset(); + rowWriter.reset(); rowWriter.zeroOutNullBytes(); ${extractorCalls} - unsafeRow.setTotalSize(bufferHolder.totalSize()); + rowWriter.setTotalSize(); return unsafeRow; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index c661e9bd3b94..a0508ce7d26a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -29,7 +29,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -134,14 +134,13 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { reader.map(_ => emptyUnsafeRow) } else { val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + val unsafeRowWriter = new UnsafeRowWriter(unsafeRow) reader.map { line => // Writes to an UnsafeRow directly - bufferHolder.reset() - unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.setTotalSize(bufferHolder.totalSize()) + unsafeRowWriter.reset() + unsafeRowWriter.write(0, line.getBytes) + unsafeRowWriter.setTotalSize() unsafeRow } } From 06e7435c7a9f5278f468b75605c9aedc26d0f304 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 17 Mar 2018 21:21:44 +0100 Subject: [PATCH 02/11] update comment --- 
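Series reviewer note (text below the "---" separator is ignored by `git am`, so this aside does not alter the commit): for orientation, here is a minimal sketch of the caller-side migration that PATCH 01 applies throughout, mirroring makeKeyRow in RowBasedKeyValueBatchSuite above. The field values 42L and "hello" are invented for illustration, and later patches in this series evolve the API again (PATCH 04 introduces a field-count constructor plus getRow()).

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
    import org.apache.spark.unsafe.types.UTF8String;

    class WriterMigrationSketch {
      static UnsafeRow makeRow() {
        // Before this series, callers built a BufferHolder by hand and
        // copied its total size into the row themselves:
        //   BufferHolder holder = new BufferHolder(row, 32);
        //   UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
        //   holder.reset();
        //   ...
        //   row.setTotalSize(holder.totalSize());
        // After PATCH 01, the writer owns its BufferHolder internally.
        UnsafeRow row = new UnsafeRow(2);                  // row with two fields
        UnsafeRowWriter writer = new UnsafeRowWriter(row); // BufferHolder created inside
        writer.reset();                                    // was: holder.reset()
        writer.write(0, 42L);                              // fixed-length field
        writer.write(1, UTF8String.fromString("hello"));   // variable-length field
        writer.setTotalSize();                             // was: row.setTotalSize(holder.totalSize())
        return row;
      }
    }
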
.../spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index f8a78fa19555..43571061a84d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -28,7 +28,7 @@ * It will remember the offset of row buffer which it starts to write, and move the cursor of row * buffer while writing. If new data(can be the input record if this is the outermost writer, or * nested struct if this is an inner writer) comes, the starting cursor of row buffer may be - * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the + * changed, so we need to call `UnsafeRowWriter.resetRowWriter` before writing, to update the * `startingOffset` and clear out null bits. * * Note that if this is the outermost writer, which means we will always write from the very From b696b7ce34a21c2d0136480a29edc746e46201d2 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 19 Mar 2018 15:30:17 +0100 Subject: [PATCH 03/11] fix test failure --- .../spark/sql/execution/datasources/text/TextFileFormat.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index a0508ce7d26a..47d202467c85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -139,7 +139,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { reader.map { line => // Writes to an UnsafeRow directly unsafeRowWriter.reset() - unsafeRowWriter.write(0, line.getBytes) + unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) unsafeRowWriter.setTotalSize() unsafeRow } From 760d08bf48fd562b7b84b2c1ec7836f027f6ae89 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 19 Mar 2018 17:36:03 +0100 Subject: [PATCH 04/11] address review comments --- .../KafkaRecordToUnsafeRowConverter.scala | 5 +-- .../expressions/codegen/BufferHolder.java | 14 ++++-- .../codegen/UnsafeArrayWriter.java | 26 ++++++----- .../expressions/codegen/UnsafeRowWriter.java | 38 ++++++++++------ .../expressions/codegen/UnsafeWriter.java | 45 ++++++++++++------- .../InterpretedUnsafeProjection.scala | 31 +++++-------- .../codegen/GenerateUnsafeProjection.scala | 20 ++++----- .../RowBasedKeyValueBatchSuite.java | 15 +++---- .../aggregate/RowBasedHashMapGenerator.scala | 6 +-- .../columnar/GenerateColumnAccessor.scala | 5 +-- .../datasources/text/TextFileFormat.scala | 5 +-- 11 files changed, 114 insertions(+), 96 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala index 05d2f9dc74b5..d99e7a7e57d6 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala +++ 
b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -26,8 +26,7 @@ import org.apache.spark.unsafe.types.UTF8String /** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ private[kafka010] class KafkaRecordToUnsafeRowConverter { - private val sharedRow = new UnsafeRow(7) - private val rowWriter = new UnsafeRowWriter(sharedRow) + private val rowWriter = new UnsafeRowWriter(7) def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { rowWriter.reset() @@ -46,6 +45,6 @@ private[kafka010] class KafkaRecordToUnsafeRowConverter { DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) rowWriter.write(6, record.timestampType.id) rowWriter.setTotalSize() - sharedRow + rowWriter.getRow() } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 3eb3e3993c3d..b97709459fbf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -35,7 +35,7 @@ * the size of the result row, after writing a record to the buffer. However, we can skip this step * if the fields of row are all fixed-length, as the size of result row is also fixed. */ -public final class BufferHolder { +final class BufferHolder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; @@ -86,11 +86,17 @@ void grow(int neededSize) { } } - byte[] buffer() { return buffer; } + byte[] buffer() { + return buffer; + } - int getCursor() { return cursor; } + int getCursor() { + return cursor; + } - void addCursor(int val) { cursor += val; } + void incrementCursor(int val) { + cursor += val; + } void reset() { cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 1a15b6cc57fd..b96798c7f6d7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -33,6 +33,9 @@ public final class UnsafeArrayWriter extends UnsafeWriter { // The number of elements in this array private int numElements; + // The element size in this array + private int elementSize; + private int headerInBytes; private void assertIndexIsValid(int index) { @@ -40,11 +43,12 @@ private void assertIndexIsValid(int index) { assert index < numElements : "index (" + index + ") should < " + numElements; } - public UnsafeArrayWriter(UnsafeWriter writer) { + public UnsafeArrayWriter(UnsafeWriter writer, int elementSize) { super(writer.getBufferHolder()); + this.elementSize = elementSize; } - public void initialize(int numElements, int elementSize) { + public void initialize(int numElements) { // We need 8 bytes to store numElements in header this.numElements = numElements; this.headerInBytes = calculateHeaderPortionInBytes(numElements); @@ -66,7 +70,7 @@ public void initialize(int numElements, int elementSize) { for (int i = elementSize * numElements; i < fixedPartInBytes; i++) { Platform.putByte(buffer(), startingOffset + headerInBytes + i, (byte) 0); } - addCursor(headerInBytes 
+ fixedPartInBytes); + incrementCursor(headerInBytes + fixedPartInBytes); } protected long getOffset(int ordinal, int elementSize) { @@ -116,37 +120,37 @@ public void setNull8Bytes(int ordinal) { public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); - _write(getElementOffset(ordinal, 1), value); + writeBoolean(getElementOffset(ordinal, 1), value); } public void write(int ordinal, byte value) { assertIndexIsValid(ordinal); - _write(getElementOffset(ordinal, 1), value); + writeByte(getElementOffset(ordinal, 1), value); } public void write(int ordinal, short value) { assertIndexIsValid(ordinal); - _write(getElementOffset(ordinal, 2), value); + writeShort(getElementOffset(ordinal, 2), value); } public void write(int ordinal, int value) { assertIndexIsValid(ordinal); - _write(getElementOffset(ordinal, 4), value); + writeInt(getElementOffset(ordinal, 4), value); } public void write(int ordinal, long value) { assertIndexIsValid(ordinal); - _write(getElementOffset(ordinal, 8), value); + writeLong(getElementOffset(ordinal, 8), value); } public void write(int ordinal, float value) { assertIndexIsValid(ordinal); - _write(getElementOffset(ordinal, 4), value); + writeFloat(getElementOffset(ordinal, 4), value); } public void write(int ordinal, double value) { assertIndexIsValid(ordinal); - _write(getElementOffset(ordinal, 8), value); + writeDouble(getElementOffset(ordinal, 8), value); } public void write(int ordinal, Decimal input, int precision, int scale) { @@ -170,7 +174,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { setOffsetAndSize(ordinal, numBytes); // move the cursor forward with 8-bytes boundary - addCursor(roundedSize); + incrementCursor(roundedSize); } } else { setNull(ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 43571061a84d..f85f37bca3e9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -42,18 +42,26 @@ public final class UnsafeRowWriter extends UnsafeWriter { private final int nullBitsSize; private final int fixedSize; - public UnsafeRowWriter(UnsafeRow row, int initialBufferSize) { - this(row, new BufferHolder(row, initialBufferSize), row.numFields()); + public UnsafeRowWriter(int numFields) { + this(new UnsafeRow(numFields)); } - public UnsafeRowWriter(UnsafeRow row) { - this(row, new BufferHolder(row), row.numFields()); + public UnsafeRowWriter(int numFields, int initialBufferSize) { + this(new UnsafeRow(numFields), initialBufferSize); } public UnsafeRowWriter(UnsafeWriter writer, int numFields) { this(null, writer.getBufferHolder(), numFields); } + private UnsafeRowWriter(UnsafeRow row) { + this(row, new BufferHolder(row), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, int initialBufferSize) { + this(row, new BufferHolder(row, initialBufferSize), row.numFields()); + } + private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) { super(holder); this.row = row; @@ -62,6 +70,10 @@ private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) { this.startingOffset = cursor(); } + public UnsafeRow getRow() { + return row; + } + public void setTotalSize() { row.setTotalSize(totalSize()); } @@ -75,7 +87,7 @@ public void resetRowWriter() { // grow 
the global buffer to make sure it has enough space to write fixed-length data. grow(fixedSize); - addCursor(fixedSize); + incrementCursor(fixedSize); zeroOutNullBytes(); } @@ -135,39 +147,39 @@ public void setOffsetAndSize(int ordinal, int currentCursor, int size) { public void write(int ordinal, boolean value) { final long offset = getFieldOffset(ordinal); Platform.putLong(buffer(), offset, 0L); - _write(offset, value); + writeBoolean(offset, value); } public void write(int ordinal, byte value) { final long offset = getFieldOffset(ordinal); Platform.putLong(buffer(), offset, 0L); - _write(offset, value); + writeByte(offset, value); } public void write(int ordinal, short value) { final long offset = getFieldOffset(ordinal); Platform.putLong(buffer(), offset, 0L); - _write(offset, value); + writeShort(offset, value); } public void write(int ordinal, int value) { final long offset = getFieldOffset(ordinal); Platform.putLong(buffer(), offset, 0L); - _write(offset, value); + writeInt(offset, value); } public void write(int ordinal, long value) { - _write(getFieldOffset(ordinal), value); + writeLong(getFieldOffset(ordinal), value); } public void write(int ordinal, float value) { final long offset = getFieldOffset(ordinal); Platform.putLong(buffer(), offset, 0L); - _write(offset, value); + writeFloat(offset, value); } public void write(int ordinal, double value) { - _write(getFieldOffset(ordinal), value); + writeDouble(getFieldOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { @@ -206,7 +218,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { } // move the cursor forward. - addCursor(16); + incrementCursor(16); } } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 24473fa5f87f..8fdc60b5fa5d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -43,18 +43,29 @@ public final BufferHolder getBufferHolder() { return holder; } - public final byte[] buffer() { return holder.buffer(); } - - public final void reset() { holder.reset(); } + public final byte[] buffer() { + return holder.buffer(); + } - public final int totalSize() { return holder.totalSize(); } + public final void reset() { + holder.reset(); + } - public final void grow(int neededSize) { holder.grow(neededSize); } + public final int totalSize() { + return holder.totalSize(); + } - public final int cursor() { return holder.getCursor(); } + public final void grow(int neededSize) { + holder.grow(neededSize); + } - public final void addCursor(int val) { holder.addCursor(val); } + public final int cursor() { + return holder.getCursor(); + } + public final void incrementCursor(int val) { + holder.incrementCursor(val); + } public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size); @@ -106,7 +117,7 @@ public final void write(int ordinal, UTF8String input) { setOffsetAndSize(ordinal, numBytes); // move the cursor forward. - addCursor(roundedSize); + incrementCursor(roundedSize); } public final void write(int ordinal, byte[] input) { @@ -128,7 +139,7 @@ public final void write(int ordinal, byte[] input, int offset, int numBytes) { setOffsetAndSize(ordinal, numBytes); // move the cursor forward. 
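// For reference, a condensed sketch of this variable-length write path end to end,
// assuming the usual Spark Platform helpers (the copyMemory arguments are illustrative):
//   grow(roundedSize);                                   // ensure buffer capacity
//   zeroOutPaddingBytes(numBytes);                       // clear the trailing pad bytes
//   Platform.copyMemory(input, offset, buffer(), cursor(), numBytes);
//   setOffsetAndSize(ordinal, numBytes);                 // record (offset, size) in the fixed slot
//   incrementCursor(roundedSize);                        // step past the 8-byte-rounded payload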
- addCursor(roundedSize); + incrementCursor(roundedSize); } public final void write(int ordinal, CalendarInterval input) { @@ -142,37 +153,37 @@ public final void write(int ordinal, CalendarInterval input) { setOffsetAndSize(ordinal, 16); // move the cursor forward. - addCursor(16); + incrementCursor(16); } - protected final void _write(long offset, boolean value) { + protected final void writeBoolean(long offset, boolean value) { Platform.putBoolean(buffer(), offset, value); } - protected final void _write(long offset, byte value) { + protected final void writeByte(long offset, byte value) { Platform.putByte(buffer(), offset, value); } - protected final void _write(long offset, short value) { + protected final void writeShort(long offset, short value) { Platform.putShort(buffer(), offset, value); } - protected final void _write(long offset, int value) { + protected final void writeInt(long offset, int value) { Platform.putInt(buffer(), offset, value); } - protected final void _write(long offset, long value) { + protected final void writeLong(long offset, long value) { Platform.putLong(buffer(), offset, value); } - protected final void _write(long offset, float value) { + protected final void writeFloat(long offset, float value) { if (Float.isNaN(value)) { value = Float.NaN; } Platform.putFloat(buffer(), offset, value); } - protected final void _write(long offset, double value) { + protected final void writeDouble(long offset, double value) { if (Double.isNaN(value)) { value = Double.NaN; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 9ba67757de8e..4d93221c0288 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -42,11 +42,8 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe /** The row representing the expression results. */ private[this] val intermediate = new GenericInternalRow(values) - /** The row returned by the projection. */ - private[this] val result = new UnsafeRow(numFields) - /* The row writer for UnsafeRow result */ - private[this] val rowWriter = new UnsafeRowWriter(result, numFields * 32) + private[this] val rowWriter = new UnsafeRowWriter(numFields, numFields * 32) /** The writer that writes the intermediate result to the result row. 
*/ private[this] val writer: InternalRow => Unit = { @@ -84,7 +81,7 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe rowWriter.reset() writer(intermediate) rowWriter.setTotalSize() - result + rowWriter.getRow() } } @@ -195,27 +192,24 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { } case ArrayType(elementType, containsNull) => - val arrayWriter = new UnsafeArrayWriter(writer) - val elementSize = getElementSize(elementType) + val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType)) val elementWriter = generateFieldWriter( arrayWriter, elementType, containsNull) (v, i) => { val tmpCursor = arrayWriter.cursor - writeArray(arrayWriter, elementWriter, v.getArray(i), elementSize) + writeArray(arrayWriter, elementWriter, v.getArray(i)) writer.setOffsetAndSize(i, tmpCursor, arrayWriter.cursor - tmpCursor) } case MapType(keyType, valueType, valueContainsNull) => - val keyArrayWriter = new UnsafeArrayWriter(writer) - val keySize = getElementSize(keyType) + val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType)) val keyWriter = generateFieldWriter( keyArrayWriter, keyType, nullable = false) - val valueArrayWriter = new UnsafeArrayWriter(writer) - val valueSize = getElementSize(valueType) + val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType)) val valueWriter = generateFieldWriter( valueArrayWriter, valueType, @@ -232,15 +226,15 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { case map => // preserve 8 bytes to write the key array numBytes later. valueArrayWriter.grow(8) - valueArrayWriter.addCursor(8) + valueArrayWriter.incrementCursor(8) // Write the keys and write the numBytes of key array into the first 8 bytes. - writeArray(keyArrayWriter, keyWriter, map.keyArray(), keySize) + writeArray(keyArrayWriter, keyWriter, map.keyArray()) Platform.putLong( valueArrayWriter.buffer, tmpCursor, valueArrayWriter.cursor - tmpCursor - 8) // Write the values. 
- writeArray(valueArrayWriter, valueWriter, map.valueArray(), valueSize) + writeArray(valueArrayWriter, valueWriter, map.valueArray()) } writer.setOffsetAndSize(i, tmpCursor, valueArrayWriter.cursor - tmpCursor) } @@ -320,8 +314,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { private def writeArray( arrayWriter: UnsafeArrayWriter, elementWriter: (SpecializedGetters, Int) => Unit, - array: ArrayData, - elementSize: Int): Unit = array match { + array: ArrayData): Unit = array match { case unsafe: UnsafeArrayData => writeUnsafeData( arrayWriter, @@ -330,7 +323,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { unsafe.getSizeInBytes) case _ => val numElements = array.numElements() - arrayWriter.initialize(numElements, elementSize) + arrayWriter.initialize(numElements) var i = 0 while (i < numElements) { elementWriter.apply(array, i) @@ -354,6 +347,6 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { writer.buffer, writer.cursor, sizeInBytes) - writer.addCursor(sizeInBytes) + writer.incrementCursor(sizeInBytes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 5cd0b6f90b2e..35089544e931 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -184,9 +184,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val arrayWriterClass = classOf[UnsafeArrayWriter].getName - val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", - v => s"$v = new $arrayWriterClass($rowWriter);") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") @@ -203,6 +200,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => 8 // we need 8 bytes to store offset and length } + val arrayWriterClass = classOf[UnsafeArrayWriter].getName + val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", + v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") + val tmpCursor = ctx.freshName("tmpCursor") val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { @@ -243,7 +244,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)} } else { final int $numElements = $tmpInput.numElements(); - $arrayWriter.initialize($numElements, $elementOrOffsetSize); + $arrayWriter.initialize($numElements); for (int $index = 0; $index < $numElements; $index++) { if ($tmpInput.isNullAt($index)) { @@ -275,7 +276,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } else { // preserve 8 bytes to write the key array numBytes later. $rowWriter.grow(8); - $rowWriter.addCursor(8); + $rowWriter.incrementCursor(8); // Remember the current cursor so that we can write numBytes of key array later. 
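// A sketch of the map layout being assembled here:
//   [ 8 bytes: numBytes of the key array ][ key array data ][ value array data ]
// The 8-byte slot is filled in afterwards, once the key array's size is known.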
final int $tmpCursor = $rowWriter.cursor(); @@ -300,7 +301,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // grow the global buffer before writing data. $rowWriter.grow($sizeInBytes); $input.writeToMemory($rowWriter.buffer(), $rowWriter.cursor()); - $rowWriter.addCursor($sizeInBytes); + $rowWriter.incrementCursor($sizeInBytes); """ } @@ -317,12 +318,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => true } - val result = ctx.addMutableState("UnsafeRow", "result", - v => s"$v = new UnsafeRow(${expressions.length});") - val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", - v => s"$v = new $rowWriterClass($result, ${numVarLenFields * 32});") + v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") val resetBufferHolder = if (numVarLenFields == 0) { "" @@ -348,7 +346,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $writeExpressions $updateRowSize """ - ExprCode(code, "false", result) + ExprCode(code, "false", s"$rowWriter.getRow()") } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index e3f25e6ef7d0..66866880d41b 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -54,33 +54,30 @@ private String getRandomString(int length) { } private UnsafeRow makeKeyRow(long k1, String k2) { - UnsafeRow row = new UnsafeRow(2); - UnsafeRowWriter writer = new UnsafeRowWriter(row); + UnsafeRowWriter writer = new UnsafeRowWriter(2); writer.reset(); writer.write(0, k1); writer.write(1, UTF8String.fromString(k2)); writer.setTotalSize(); - return row; + return writer.getRow(); } private UnsafeRow makeKeyRow(long k1, long k2) { - UnsafeRow row = new UnsafeRow(2); - UnsafeRowWriter writer = new UnsafeRowWriter(row); + UnsafeRowWriter writer = new UnsafeRowWriter(2); writer.reset(); writer.write(0, k1); writer.write(1, k2); writer.setTotalSize(); - return row; + return writer.getRow(); } private UnsafeRow makeValueRow(long v1, long v2) { - UnsafeRow row = new UnsafeRow(2); - UnsafeRowWriter writer = new UnsafeRowWriter(row); + UnsafeRowWriter writer = new UnsafeRowWriter(2); writer.reset(); writer.write(0, v1); writer.write(1, v2); writer.setTotalSize(); - return row; + return writer.getRow(); } private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 7bfd1972dab0..3990f7bd7688 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -165,15 +165,15 @@ class RowBasedHashMapGenerator( | if (buckets[idx] == -1) { | if (numRows < capacity && !isBatchFull) { | // creating the unsafe for new entry - | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length}); | 
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter | = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( - | agg_result, - | ${numVarLenFields * 32}); + | ${groupingKeySchema.length}, ${numVarLenFields * 32}); | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed | agg_rowWriter.zeroOutNullBytes(); | ${createUnsafeRowForKey}; | agg_rowWriter.setTotalSize(); + | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result + | = agg_rowWriter.getRow(); | Object kbase = agg_result.getBaseObject(); | long koff = agg_result.getBaseOffset(); | int klen = agg_result.getSizeInBytes(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index d3790dcb6c46..6bbe19e702b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -165,8 +165,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; - private UnsafeRow unsafeRow = new UnsafeRow($numFields); - private UnsafeRowWriter rowWriter = new UnsafeRowWriter(unsafeRow); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter($numFields); private MutableUnsafeRow mutableRow = null; private int currentRow = 0; @@ -215,7 +214,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera rowWriter.zeroOutNullBytes(); ${extractorCalls} rowWriter.setTotalSize(); - return unsafeRow; + return rowWriter.getRow(); } ${ctx.declareAddedFunctions()} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 47d202467c85..f9d833eb5d32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -133,15 +133,14 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { val emptyUnsafeRow = new UnsafeRow(0) reader.map(_ => emptyUnsafeRow) } else { - val unsafeRow = new UnsafeRow(1) - val unsafeRowWriter = new UnsafeRowWriter(unsafeRow) + val unsafeRowWriter = new UnsafeRowWriter(1) reader.map { line => // Writes to an UnsafeRow directly unsafeRowWriter.reset() unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) unsafeRowWriter.setTotalSize() - unsafeRow + unsafeRowWriter.getRow() } } } From c342f0da350f53ad0683bda115174d4debbcf1e0 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 19 Mar 2018 18:46:41 +0100 Subject: [PATCH 05/11] address review comment --- .../expressions/codegen/BufferHolder.java | 18 +++++++++++++ .../codegen/UnsafeArrayWriter.java | 4 +-- .../expressions/codegen/UnsafeRowWriter.java | 8 +++--- .../expressions/codegen/UnsafeWriter.java | 13 ++++++++-- .../InterpretedUnsafeProjection.scala | 12 ++++----- .../codegen/GenerateUnsafeProjection.scala | 26 +++++++++---------- 6 files changed, 53 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java 
index b97709459fbf..1fca522924ed 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -43,6 +43,8 @@ final class BufferHolder { private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; private final int fixedSize; + private int[] cursorStack = new int[1]; + private int cursorStackIndex = 0; BufferHolder(UnsafeRow row) { this(row, 64); @@ -98,6 +100,22 @@ void incrementCursor(int val) { cursor += val; } + int pushCursor() { + if (cursorStack.length <= cursorStackIndex) { + int newSize = (cursorStack.length * 3 + 1) / 2; + int[] tmp = new int[newSize]; + System.arraycopy(cursorStack, 0, tmp, 0, cursorStack.length); + cursorStack = tmp; + } + int cur = getCursor(); + cursorStack[cursorStackIndex++] = cur; + return cursor; + } + + int popCursor() { + return cursorStack[--cursorStackIndex]; + } + void reset() { cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index b96798c7f6d7..7577edfc71ec 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -82,9 +82,9 @@ private long getElementOffset(int ordinal, int elementSize) { } @Override - public void setOffsetAndSize(int ordinal, int currentCursor, int size) { + public void setOffsetAndSizeFromMark(int ordinal) { assertIndexIsValid(ordinal); - _setOffsetAndSize(ordinal, currentCursor, size); + _setOffsetAndSizeFromMark(ordinal); } private void setNullBit(int ordinal) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index f85f37bca3e9..effce98ab4c5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -131,8 +131,8 @@ public void setNull8Bytes(int ordinal) { } @Override - protected final long getOffset(int oridinal, int elementSize) { - return getFieldOffset(oridinal); + protected final long getOffset(int ordinal, int elementSize) { + return getFieldOffset(ordinal); } public long getFieldOffset(int ordinal) { @@ -140,8 +140,8 @@ public long getFieldOffset(int ordinal) { } @Override - public void setOffsetAndSize(int ordinal, int currentCursor, int size) { - _setOffsetAndSize(ordinal, currentCursor, size); + public void setOffsetAndSizeFromMark(int ordinal) { + _setOffsetAndSizeFromMark(ordinal); } public void write(int ordinal, boolean value) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 8fdc60b5fa5d..0c3d05d85612 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -67,13 +67,22 @@ public final void incrementCursor(int val) { holder.incrementCursor(val); 
} - public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size); + public final int markCursor() { + return holder.pushCursor(); + } + + public abstract void setOffsetAndSizeFromMark(int ordinal); + + protected void _setOffsetAndSizeFromMark(int ordinal) { + int mark = holder.popCursor(); + setOffsetAndSize(ordinal, mark, cursor() - mark); + } protected void setOffsetAndSize(int ordinal, int size) { setOffsetAndSize(ordinal, cursor(), size); } - protected void _setOffsetAndSize(int ordinal, int currentCursor, int size) { + protected void setOffsetAndSize(int ordinal, int currentCursor, int size) { final long relativeOffset = currentCursor - startingOffset; final long offsetAndSize = (relativeOffset << 32) | (long)size; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 4d93221c0288..5c782e9dc8e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -174,7 +174,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { val rowWriter = new UnsafeRowWriter(writer, numFields) val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val tmpCursor = rowWriter.cursor + rowWriter.markCursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => writeUnsafeData( @@ -188,7 +188,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { rowWriter.resetRowWriter() structWriter.apply(row) } - writer.setOffsetAndSize(i, tmpCursor, rowWriter.cursor - tmpCursor) + writer.setOffsetAndSizeFromMark(i) } case ArrayType(elementType, containsNull) => @@ -198,9 +198,9 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { elementType, containsNull) (v, i) => { - val tmpCursor = arrayWriter.cursor + arrayWriter.markCursor() writeArray(arrayWriter, elementWriter, v.getArray(i)) - writer.setOffsetAndSize(i, tmpCursor, arrayWriter.cursor - tmpCursor) + writer.setOffsetAndSizeFromMark(i) } case MapType(keyType, valueType, valueContainsNull) => @@ -215,7 +215,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { valueType, valueContainsNull) (v, i) => { - val tmpCursor = valueArrayWriter.cursor + val tmpCursor = valueArrayWriter.markCursor() v.getMap(i) match { case map: UnsafeMapData => writeUnsafeData( @@ -236,7 +236,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { // Write the values. 
writeArray(valueArrayWriter, valueWriter, map.valueArray()) } - writer.setOffsetAndSize(i, tmpCursor, valueArrayWriter.cursor - tmpCursor) + writer.setOffsetAndSizeFromMark(i) } case udt: UserDefinedType[_] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 35089544e931..40b6b8d9f0ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -97,7 +97,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case udt: UserDefinedType[_] => udt.sqlType case other => other } - val tmpCursor = ctx.freshName("tmpCursor") val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => @@ -111,27 +110,27 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $rowWriter.cursor(); + $rowWriter.markCursor(); ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $rowWriter.cursor() - $tmpCursor); + $rowWriter.setOffsetAndSizeFromMark($index); """ case a @ ArrayType(et, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $rowWriter.cursor(); + $rowWriter.markCursor(); ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $rowWriter.cursor() - $tmpCursor); + $rowWriter.setOffsetAndSizeFromMark($index); """ case m @ MapType(kt, vt, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. 
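// In sketch form, the mark-based pattern that replaces the explicit tmpCursor here
// (the ordinal and the nested value are placeholders):
//   writer.markCursor();                        // push the current cursor on the holder's stack
//   ... write the nested struct/array/map ...
//   writer.setOffsetAndSizeFromMark(ordinal);   // pop it and store (offset, size) for the field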
- final int $tmpCursor = $rowWriter.cursor(); + $rowWriter.markCursor(); ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $rowWriter.cursor() - $tmpCursor); + $rowWriter.setOffsetAndSizeFromMark($index); """ case t: DecimalType => @@ -204,28 +203,27 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") - val tmpCursor = ctx.freshName("tmpCursor") val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" - final int $tmpCursor = $rowWriter.cursor(); + $rowWriter.markCursor(); ${writeStructToBuffer(ctx, element, t.map(_.dataType), rowWriter)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $rowWriter.cursor() - $tmpCursor); + $arrayWriter.setOffsetAndSizeFromMark($index); """ case a @ ArrayType(et, _) => s""" - final int $tmpCursor = $rowWriter.cursor(); + $rowWriter.markCursor(); ${writeArrayToBuffer(ctx, element, et, rowWriter)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $rowWriter.cursor() - $tmpCursor); + $arrayWriter.setOffsetAndSizeFromMark($index); """ case m @ MapType(kt, vt, _) => s""" - final int $tmpCursor = $rowWriter.cursor(); + $rowWriter.markCursor(); ${writeMapToBuffer(ctx, element, kt, vt, rowWriter)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $rowWriter.cursor() - $tmpCursor); + $arrayWriter.setOffsetAndSizeFromMark($index); """ case t: DecimalType => From 3637a5c171ab856051b64bdd3fe01d40c5b2b569 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 20 Mar 2018 09:02:06 +0100 Subject: [PATCH 06/11] refinements --- .../catalyst/expressions/codegen/BufferHolder.java | 3 +-- .../expressions/InterpretedUnsafeProjection.scala | 6 +++--- .../codegen/GenerateUnsafeProjection.scala | 12 ++++++------ 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 1fca522924ed..76b74bb92e2a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -102,8 +102,7 @@ void incrementCursor(int val) { int pushCursor() { if (cursorStack.length <= cursorStackIndex) { - int newSize = (cursorStack.length * 3 + 1) / 2; - int[] tmp = new int[newSize]; + int[] tmp = new int[(cursorStack.length * 3 + 1) / 2]; System.arraycopy(cursorStack, 0, tmp, 0, cursorStack.length); cursorStack = tmp; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 5c782e9dc8e1..bf1c112629a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -174,7 +174,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { val rowWriter = new UnsafeRowWriter(writer, numFields) val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - rowWriter.markCursor() + writer.markCursor() v.getStruct(i, 
fields.length) match { case row: UnsafeRow => writeUnsafeData( @@ -198,7 +198,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { elementType, containsNull) (v, i) => { - arrayWriter.markCursor() + writer.markCursor() writeArray(arrayWriter, elementWriter, v.getArray(i)) writer.setOffsetAndSizeFromMark(i) } @@ -215,7 +215,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { valueType, valueContainsNull) (v, i) => { - val tmpCursor = valueArrayWriter.markCursor() + val tmpCursor = writer.markCursor() v.getMap(i) match { case map: UnsafeMapData => writeUnsafeData( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 40b6b8d9f0ee..505d7bbdb314 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -207,22 +207,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val writeElement = et match { case t: StructType => s""" - $rowWriter.markCursor(); - ${writeStructToBuffer(ctx, element, t.map(_.dataType), rowWriter)} + $arrayWriter.markCursor(); + ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} $arrayWriter.setOffsetAndSizeFromMark($index); """ case a @ ArrayType(et, _) => s""" - $rowWriter.markCursor(); - ${writeArrayToBuffer(ctx, element, et, rowWriter)} + $arrayWriter.markCursor(); + ${writeArrayToBuffer(ctx, element, et, arrayWriter)} $arrayWriter.setOffsetAndSizeFromMark($index); """ case m @ MapType(kt, vt, _) => s""" - $rowWriter.markCursor(); - ${writeMapToBuffer(ctx, element, kt, vt, rowWriter)} + $arrayWriter.markCursor(); + ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} $arrayWriter.setOffsetAndSizeFromMark($index); """ From 3ad76384182235f043231261ff6f65a71e2a0bf4 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 28 Mar 2018 20:40:05 +0100 Subject: [PATCH 07/11] address review comment --- .../expressions/codegen/BufferHolder.java | 17 ------------ .../codegen/UnsafeArrayWriter.java | 4 +-- .../expressions/codegen/UnsafeRowWriter.java | 4 +-- .../expressions/codegen/UnsafeWriter.java | 9 ++----- .../InterpretedUnsafeProjection.scala | 14 +++++----- .../codegen/GenerateUnsafeProjection.scala | 26 ++++++++++--------- 6 files changed, 27 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 76b74bb92e2a..b97709459fbf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -43,8 +43,6 @@ final class BufferHolder { private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; private final int fixedSize; - private int[] cursorStack = new int[1]; - private int cursorStackIndex = 0; BufferHolder(UnsafeRow row) { this(row, 64); @@ -100,21 +98,6 @@ void incrementCursor(int val) { cursor += val; } - int pushCursor() { - if (cursorStack.length <= cursorStackIndex) { - int[] tmp = new int[(cursorStack.length * 3 + 1) / 2]; - System.arraycopy(cursorStack, 0, tmp, 0, 
cursorStack.length); - cursorStack = tmp; - } - int cur = getCursor(); - cursorStack[cursorStackIndex++] = cur; - return cursor; - } - - int popCursor() { - return cursorStack[--cursorStackIndex]; - } - void reset() { cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 7577edfc71ec..519c57494f90 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -82,9 +82,9 @@ private long getElementOffset(int ordinal, int elementSize) { } @Override - public void setOffsetAndSizeFromMark(int ordinal) { + public void setOffsetAndSizeFromMark(int ordinal, int mark) { assertIndexIsValid(ordinal); - _setOffsetAndSizeFromMark(ordinal); + _setOffsetAndSizeFromMark(ordinal, mark); } private void setNullBit(int ordinal) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index effce98ab4c5..1c87e47bfa8e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -140,8 +140,8 @@ public long getFieldOffset(int ordinal) { } @Override - public void setOffsetAndSizeFromMark(int ordinal) { - _setOffsetAndSizeFromMark(ordinal); + public void setOffsetAndSizeFromMark(int ordinal, int mark) { + _setOffsetAndSizeFromMark(ordinal, mark); } public void write(int ordinal, boolean value) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 0c3d05d85612..3d5ed1e42d14 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -67,14 +67,9 @@ public final void incrementCursor(int val) { holder.incrementCursor(val); } - public final int markCursor() { - return holder.pushCursor(); - } - - public abstract void setOffsetAndSizeFromMark(int ordinal); + public abstract void setOffsetAndSizeFromMark(int ordinal, int mark); - protected void _setOffsetAndSizeFromMark(int ordinal) { - int mark = holder.popCursor(); + protected void _setOffsetAndSizeFromMark(int ordinal, int mark) { setOffsetAndSize(ordinal, mark, cursor() - mark); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index bf1c112629a8..f285b18d7ac5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -174,7 +174,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { val rowWriter = new UnsafeRowWriter(writer, numFields) val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - 
writer.markCursor() + val markCursor = writer.cursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => writeUnsafeData( @@ -188,7 +188,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { rowWriter.resetRowWriter() structWriter.apply(row) } - writer.setOffsetAndSizeFromMark(i) + writer.setOffsetAndSizeFromMark(i, markCursor) } case ArrayType(elementType, containsNull) => @@ -198,9 +198,9 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { elementType, containsNull) (v, i) => { - writer.markCursor() + val markCursor = writer.cursor() writeArray(arrayWriter, elementWriter, v.getArray(i)) - writer.setOffsetAndSizeFromMark(i) + writer.setOffsetAndSizeFromMark(i, markCursor) } case MapType(keyType, valueType, valueContainsNull) => @@ -215,7 +215,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { valueType, valueContainsNull) (v, i) => { - val tmpCursor = writer.markCursor() + val markCursor = writer.cursor() v.getMap(i) match { case map: UnsafeMapData => writeUnsafeData( @@ -231,12 +231,12 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { // Write the keys and write the numBytes of key array into the first 8 bytes. writeArray(keyArrayWriter, keyWriter, map.keyArray()) Platform.putLong( - valueArrayWriter.buffer, tmpCursor, valueArrayWriter.cursor - tmpCursor - 8) + valueArrayWriter.buffer, markCursor, valueArrayWriter.cursor - markCursor - 8) // Write the values. writeArray(valueArrayWriter, valueWriter, map.valueArray()) } - writer.setOffsetAndSizeFromMark(i) + writer.setOffsetAndSizeFromMark(i, markCursor) } case udt: UserDefinedType[_] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 505d7bbdb314..7e1b6b6f55f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -104,33 +104,34 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } + val markCursor = ctx.freshName("markCursor") val writeField = dt match { case t: StructType => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - $rowWriter.markCursor(); + final int $markCursor = $rowWriter.cursor(); ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} - $rowWriter.setOffsetAndSizeFromMark($index); + $rowWriter.setOffsetAndSizeFromMark($index, $markCursor); """ case a @ ArrayType(et, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - $rowWriter.markCursor(); + final int $markCursor = $rowWriter.cursor(); ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} - $rowWriter.setOffsetAndSizeFromMark($index); + $rowWriter.setOffsetAndSizeFromMark($index, $markCursor); """ case m @ MapType(kt, vt, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. 
- $rowWriter.markCursor(); + final int $markCursor = $rowWriter.cursor(); ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} - $rowWriter.setOffsetAndSizeFromMark($index); + $rowWriter.setOffsetAndSizeFromMark($index, $markCursor); """ case t: DecimalType => @@ -202,28 +203,29 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") + val markCursor = ctx.freshName("markCursor") val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" - $arrayWriter.markCursor(); + final int $markCursor = $arrayWriter.cursor(); ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} - $arrayWriter.setOffsetAndSizeFromMark($index); + $arrayWriter.setOffsetAndSizeFromMark($index, $markCursor); """ case a @ ArrayType(et, _) => s""" - $arrayWriter.markCursor(); + final int $markCursor = $arrayWriter.cursor(); ${writeArrayToBuffer(ctx, element, et, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromMark($index); + $arrayWriter.setOffsetAndSizeFromMark($index, $markCursor); """ case m @ MapType(kt, vt, _) => s""" - $arrayWriter.markCursor(); + final int $markCursor = $arrayWriter.cursor(); ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromMark($index); + $arrayWriter.setOffsetAndSizeFromMark($index, $markCursor); """ case t: DecimalType => From a94d4703d237dd8fcd6958b546bfdd2048b815ca Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 29 Mar 2018 21:18:51 +0100 Subject: [PATCH 08/11] address review comments --- .../expressions/codegen/BufferHolder.java | 4 --- .../codegen/UnsafeArrayWriter.java | 32 ++++++++----------- .../expressions/codegen/UnsafeRowWriter.java | 13 ++++---- .../expressions/codegen/UnsafeWriter.java | 8 ++--- .../InterpretedUnsafeProjection.scala | 17 ++++++---- .../codegen/GenerateUnsafeProjection.scala | 28 ++++++++-------- 6 files changed, 47 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index b97709459fbf..6dbc46eba7b3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -30,10 +30,6 @@ * this class per writing program, so that the memory segment/data buffer can be reused. Note that * for each incoming record, we should call `reset` of BufferHolder instance before write the record * and reuse the data buffer. - * - * Generally we should call `UnsafeRowWriter.setTotalSize` using `BufferHolder.totalSize` to update - * the size of the result row, after writing a record to the buffer. However, we can skip this step - * if the fields of row are all fixed-length, as the size of result row is also fixed. 
*/ final class BufferHolder { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 519c57494f90..f5818879dd00 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -73,18 +73,14 @@ public void initialize(int numElements) { incrementCursor(headerInBytes + fixedPartInBytes); } - protected long getOffset(int ordinal, int elementSize) { - return getElementOffset(ordinal, elementSize); - } - - private long getElementOffset(int ordinal, int elementSize) { + private long getElementOffset(int ordinal) { return startingOffset + headerInBytes + ordinal * elementSize; } @Override - public void setOffsetAndSizeFromMark(int ordinal, int mark) { + public void setOffsetAndSizeFromPreviousCursor(int ordinal, int mark) { assertIndexIsValid(ordinal); - _setOffsetAndSizeFromMark(ordinal, mark); + _setOffsetAndSizeFromPreviousCursor(ordinal, mark); } private void setNullBit(int ordinal) { @@ -95,62 +91,62 @@ private void setNullBit(int ordinal) { public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putByte(buffer(), getElementOffset(ordinal, 1), (byte)0); + Platform.putByte(buffer(), getElementOffset(ordinal), (byte)0); } public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putShort(buffer(), getElementOffset(ordinal, 2), (short)0); + Platform.putShort(buffer(), getElementOffset(ordinal), (short)0); } public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putInt(buffer(), getElementOffset(ordinal, 4), 0); + Platform.putInt(buffer(), getElementOffset(ordinal), 0); } public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putLong(buffer(), getElementOffset(ordinal, 8), (long)0); + Platform.putLong(buffer(), getElementOffset(ordinal), (long)0); } public void setNull(int ordinal) { setNull8Bytes(ordinal); } public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); - writeBoolean(getElementOffset(ordinal, 1), value); + writeBoolean(getElementOffset(ordinal), value); } public void write(int ordinal, byte value) { assertIndexIsValid(ordinal); - writeByte(getElementOffset(ordinal, 1), value); + writeByte(getElementOffset(ordinal), value); } public void write(int ordinal, short value) { assertIndexIsValid(ordinal); - writeShort(getElementOffset(ordinal, 2), value); + writeShort(getElementOffset(ordinal), value); } public void write(int ordinal, int value) { assertIndexIsValid(ordinal); - writeInt(getElementOffset(ordinal, 4), value); + writeInt(getElementOffset(ordinal), value); } public void write(int ordinal, long value) { assertIndexIsValid(ordinal); - writeLong(getElementOffset(ordinal, 8), value); + writeLong(getElementOffset(ordinal), value); } public void write(int ordinal, float value) { assertIndexIsValid(ordinal); - writeFloat(getElementOffset(ordinal, 4), value); + writeFloat(getElementOffset(ordinal), value); } public void write(int ordinal, double value) { assertIndexIsValid(ordinal); - writeDouble(getElementOffset(ordinal, 8), value); + 
writeDouble(getElementOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 1c87e47bfa8e..a1b0cc5486bf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -34,6 +34,10 @@ * Note that if this is the outermost writer, which means we will always write from the very * beginning of the global row buffer, we don't need to update `startingOffset` and can just call * `zeroOutNullBytes` before writing new data. + * + * Generally we should call `UnsafeRowWriter.setTotalSize` to update the size of the result row, + * after writing a record to the buffer. However, we can skip this step if the fields of the row + * are all fixed-length, as the size of the result row is also fixed. */ public final class UnsafeRowWriter extends UnsafeWriter { @@ -130,18 +134,13 @@ public void setNull8Bytes(int ordinal) { setNullAt(ordinal); } - @Override - protected final long getOffset(int ordinal, int elementSize) { - return getFieldOffset(ordinal); - } - public long getFieldOffset(int ordinal) { return startingOffset + nullBitsSize + 8 * ordinal; } @Override - public void setOffsetAndSizeFromMark(int ordinal, int mark) { - _setOffsetAndSizeFromMark(ordinal, mark); + public void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) { + _setOffsetAndSizeFromPreviousCursor(ordinal, previousCursor); } public void write(int ordinal, boolean value) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 3d5ed1e42d14..f2279e8979e6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -67,10 +67,10 @@ public final void incrementCursor(int val) { holder.incrementCursor(val); } - public abstract void setOffsetAndSizeFromMark(int ordinal, int mark); + public abstract void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor); - protected void _setOffsetAndSizeFromMark(int ordinal, int mark) { - setOffsetAndSize(ordinal, mark, cursor() - mark); + protected void _setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) { + setOffsetAndSize(ordinal, previousCursor, cursor() - previousCursor); } protected void setOffsetAndSize(int ordinal, int size) { @@ -90,8 +90,6 @@ protected final void zeroOutPaddingBytes(int numBytes) { } } - protected abstract long getOffset(int ordinal, int elementSize); - public abstract void setNull1Bytes(int ordinal); public abstract void setNull2Bytes(int ordinal); public abstract void setNull4Bytes(int ordinal); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index f285b18d7ac5..4eeb2a29757b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -174,7 +174,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { val rowWriter = new UnsafeRowWriter(writer, numFields) val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val markCursor = writer.cursor() + val previousCursor = writer.cursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => writeUnsafeData( @@ -188,7 +188,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { rowWriter.resetRowWriter() structWriter.apply(row) } - writer.setOffsetAndSizeFromMark(i, markCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case ArrayType(elementType, containsNull) => @@ -198,9 +198,9 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { elementType, containsNull) (v, i) => { - val markCursor = writer.cursor() + val previousCursor = writer.cursor() writeArray(arrayWriter, elementWriter, v.getArray(i)) - writer.setOffsetAndSizeFromMark(i, markCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case MapType(keyType, valueType, valueContainsNull) => @@ -215,7 +215,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { valueType, valueContainsNull) (v, i) => { - val markCursor = writer.cursor() + val previousCursor = writer.cursor() v.getMap(i) match { case map: UnsafeMapData => writeUnsafeData( @@ -231,12 +231,15 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { // Write the keys and write the numBytes of key array into the first 8 bytes. writeArray(keyArrayWriter, keyWriter, map.keyArray()) Platform.putLong( - valueArrayWriter.buffer, markCursor, valueArrayWriter.cursor - markCursor - 8) + valueArrayWriter.buffer, + previousCursor, + valueArrayWriter.cursor - previousCursor - 8 + ) // Write the values. writeArray(valueArrayWriter, valueWriter, map.valueArray()) } - writer.setOffsetAndSizeFromMark(i, markCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case udt: UserDefinedType[_] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7e1b6b6f55f9..c14900dac146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -104,34 +104,34 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } - val markCursor = ctx.freshName("markCursor") + val previousCursor = ctx.freshName("previousCursor") val writeField = dt match { case t: StructType => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $markCursor = $rowWriter.cursor(); + final int $previousCursor = $rowWriter.cursor(); ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} - $rowWriter.setOffsetAndSizeFromMark($index, $markCursor); + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case a @ ArrayType(et, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. 
- final int $markCursor = $rowWriter.cursor(); + final int $previousCursor = $rowWriter.cursor(); ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} - $rowWriter.setOffsetAndSizeFromMark($index, $markCursor); + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case m @ MapType(kt, vt, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $markCursor = $rowWriter.cursor(); + final int $previousCursor = $rowWriter.cursor(); ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} - $rowWriter.setOffsetAndSizeFromMark($index, $markCursor); + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case t: DecimalType => @@ -203,29 +203,29 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") - val markCursor = ctx.freshName("markCursor") + val previousCursor = ctx.freshName("previousCursor") val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" - final int $markCursor = $arrayWriter.cursor(); + final int $previousCursor = $arrayWriter.cursor(); ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} - $arrayWriter.setOffsetAndSizeFromMark($index, $markCursor); + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case a @ ArrayType(et, _) => s""" - final int $markCursor = $arrayWriter.cursor(); + final int $previousCursor = $arrayWriter.cursor(); ${writeArrayToBuffer(ctx, element, et, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromMark($index, $markCursor); + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case m @ MapType(kt, vt, _) => s""" - final int $markCursor = $arrayWriter.cursor(); + final int $previousCursor = $arrayWriter.cursor(); ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromMark($index, $markCursor); + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case t: DecimalType => From 6caf11cc0244a43e1b01c3b76942eea67a7318f5 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 1 Apr 2018 01:43:14 +0100 Subject: [PATCH 09/11] address review comments --- .../catalyst/expressions/codegen/BufferHolder.java | 2 +- .../expressions/codegen/UnsafeArrayWriter.java | 10 ++-------- .../expressions/codegen/UnsafeRowWriter.java | 9 ++------- .../catalyst/expressions/codegen/UnsafeWriter.java | 14 ++++++-------- .../expressions/InterpretedUnsafeProjection.scala | 4 ++-- .../codegen/GenerateUnsafeProjection.scala | 4 ++-- 6 files changed, 15 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 6dbc46eba7b3..69573ef9bee3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -90,7 +90,7 @@ int getCursor() { return cursor; } - void incrementCursor(int val) { + void increaseCursor(int val) { cursor += val; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index f5818879dd00..12f939544950 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -70,19 +70,13 @@ public void initialize(int numElements) { for (int i = elementSize * numElements; i < fixedPartInBytes; i++) { Platform.putByte(buffer(), startingOffset + headerInBytes + i, (byte) 0); } - incrementCursor(headerInBytes + fixedPartInBytes); + increaseCursor(headerInBytes + fixedPartInBytes); } private long getElementOffset(int ordinal) { return startingOffset + headerInBytes + ordinal * elementSize; } - @Override - public void setOffsetAndSizeFromPreviousCursor(int ordinal, int mark) { - assertIndexIsValid(ordinal); - _setOffsetAndSizeFromPreviousCursor(ordinal, mark); - } - private void setNullBit(int ordinal) { assertIndexIsValid(ordinal); BitSetMethods.set(buffer(), startingOffset + 8, ordinal); @@ -170,7 +164,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { setOffsetAndSize(ordinal, numBytes); // move the cursor forward with 8-bytes boundary - incrementCursor(roundedSize); + increaseCursor(roundedSize); } } else { setNull(ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index a1b0cc5486bf..4cb6822f385d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -91,7 +91,7 @@ public void resetRowWriter() { // grow the global buffer to make sure it has enough space to write fixed-length data. grow(fixedSize); - incrementCursor(fixedSize); + increaseCursor(fixedSize); zeroOutNullBytes(); } @@ -138,11 +138,6 @@ public long getFieldOffset(int ordinal) { return startingOffset + nullBitsSize + 8 * ordinal; } - @Override - public void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) { - _setOffsetAndSizeFromPreviousCursor(ordinal, previousCursor); - } - public void write(int ordinal, boolean value) { final long offset = getFieldOffset(ordinal); Platform.putLong(buffer(), offset, 0L); @@ -217,7 +212,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { } // move the cursor forward. 
- incrementCursor(16); + increaseCursor(16); } } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index f2279e8979e6..d6d907af8694 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -63,13 +63,11 @@ public final int cursor() { return holder.getCursor(); } - public final void incrementCursor(int val) { - holder.incrementCursor(val); + public final void increaseCursor(int val) { + holder.increaseCursor(val); } - public abstract void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor); - - protected void _setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) { + public final void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) { setOffsetAndSize(ordinal, previousCursor, cursor() - previousCursor); } @@ -119,7 +117,7 @@ public final void write(int ordinal, UTF8String input) { setOffsetAndSize(ordinal, numBytes); // move the cursor forward. - incrementCursor(roundedSize); + increaseCursor(roundedSize); } public final void write(int ordinal, byte[] input) { @@ -141,7 +139,7 @@ public final void write(int ordinal, byte[] input, int offset, int numBytes) { setOffsetAndSize(ordinal, numBytes); // move the cursor forward. - incrementCursor(roundedSize); + increaseCursor(roundedSize); } public final void write(int ordinal, CalendarInterval input) { @@ -155,7 +153,7 @@ public final void write(int ordinal, CalendarInterval input) { setOffsetAndSize(ordinal, 16); // move the cursor forward. - incrementCursor(16); + increaseCursor(16); } protected final void writeBoolean(long offset, boolean value) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 4eeb2a29757b..cecbcd9fc73e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -226,7 +226,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { case map => // preserve 8 bytes to write the key array numBytes later. valueArrayWriter.grow(8) - valueArrayWriter.incrementCursor(8) + valueArrayWriter.increaseCursor(8) // Write the keys and write the numBytes of key array into the first 8 bytes. 
writeArray(keyArrayWriter, keyWriter, map.keyArray()) @@ -350,6 +350,6 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { writer.buffer, writer.cursor, sizeInBytes) - writer.incrementCursor(sizeInBytes) + writer.increaseCursor(sizeInBytes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index c14900dac146..ebab13055733 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -276,7 +276,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } else { // preserve 8 bytes to write the key array numBytes later. $rowWriter.grow(8); - $rowWriter.incrementCursor(8); + $rowWriter.increaseCursor(8); // Remember the current cursor so that we can write numBytes of key array later. final int $tmpCursor = $rowWriter.cursor(); @@ -301,7 +301,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // grow the global buffer before writing data. $rowWriter.grow($sizeInBytes); $input.writeToMemory($rowWriter.buffer(), $rowWriter.cursor()); - $rowWriter.incrementCursor($sizeInBytes); + $rowWriter.increaseCursor($sizeInBytes); """ } From 9dc36b7dbcf74335df244d695ce72d75a74328f3 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 2 Apr 2018 03:42:15 +0100 Subject: [PATCH 10/11] address review comment --- .../expressions/codegen/BufferHolder.java | 2 +- .../codegen/UnsafeArrayWriter.java | 18 ++++++------ .../expressions/codegen/UnsafeRowWriter.java | 28 +++++++++---------- .../expressions/codegen/UnsafeWriter.java | 28 +++++++++---------- .../InterpretedUnsafeProjection.scala | 4 +-- .../codegen/GenerateUnsafeProjection.scala | 4 +-- 6 files changed, 42 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 69573ef9bee3..537ef244b7e8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -82,7 +82,7 @@ void grow(int neededSize) { } } - byte[] buffer() { + byte[] getBuffer() { return buffer; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 12f939544950..a78dd970d23e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -61,14 +61,14 @@ public void initialize(int numElements) { holder.grow(headerInBytes + fixedPartInBytes); // Write numElements and clear out null bits to header - Platform.putLong(buffer(), startingOffset, numElements); + Platform.putLong(getBuffer(), startingOffset, numElements); for (int i = 8; i < headerInBytes; i += 8) { - Platform.putLong(buffer(), startingOffset + i, 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } // fill 0 into reminder part of 
8-bytes alignment in unsafe array for (int i = elementSize * numElements; i < fixedPartInBytes; i++) { - Platform.putByte(buffer(), startingOffset + headerInBytes + i, (byte) 0); + Platform.putByte(getBuffer(), startingOffset + headerInBytes + i, (byte) 0); } increaseCursor(headerInBytes + fixedPartInBytes); } @@ -79,31 +79,31 @@ private long getElementOffset(int ordinal) { private void setNullBit(int ordinal) { assertIndexIsValid(ordinal); - BitSetMethods.set(buffer(), startingOffset + 8, ordinal); + BitSetMethods.set(getBuffer(), startingOffset + 8, ordinal); } public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putByte(buffer(), getElementOffset(ordinal), (byte)0); + writeByte(getElementOffset(ordinal), (byte)0); } public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putShort(buffer(), getElementOffset(ordinal), (short)0); + writeShort(getElementOffset(ordinal), (short)0); } public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putInt(buffer(), getElementOffset(ordinal), 0); + writeInt(getElementOffset(ordinal), 0); } public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putLong(buffer(), getElementOffset(ordinal), (long)0); + writeLong(getElementOffset(ordinal), 0); } public void setNull(int ordinal) { setNull8Bytes(ordinal); } @@ -160,7 +160,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, buffer(), cursor(), numBytes); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, numBytes); // move the cursor forward with 8-bytes boundary diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 4cb6822f385d..8d88cc95504f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -101,17 +101,17 @@ public void resetRowWriter() { */ public void zeroOutNullBytes() { for (int i = 0; i < nullBitsSize; i += 8) { - Platform.putLong(buffer(), startingOffset + i, 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } } public boolean isNullAt(int ordinal) { - return BitSetMethods.isSet(buffer(), startingOffset, ordinal); + return BitSetMethods.isSet(getBuffer(), startingOffset, ordinal); } public void setNullAt(int ordinal) { - BitSetMethods.set(buffer(), startingOffset, ordinal); - Platform.putLong(buffer(), getFieldOffset(ordinal), 0L); + BitSetMethods.set(getBuffer(), startingOffset, ordinal); + write(ordinal, 0L); } @Override @@ -140,25 +140,25 @@ public long getFieldOffset(int ordinal) { public void write(int ordinal, boolean value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(buffer(), offset, 0L); + writeLong(offset, 0L); writeBoolean(offset, value); } public void write(int ordinal, byte value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(buffer(), offset, 0L); + writeLong(offset, 0L); writeByte(offset, value); } public void write(int ordinal, short value) { 
final long offset = getFieldOffset(ordinal); - Platform.putLong(buffer(), offset, 0L); + writeLong(offset, 0L); writeShort(offset, value); } public void write(int ordinal, int value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(buffer(), offset, 0L); + writeLong(offset, 0L); writeInt(offset, value); } @@ -168,7 +168,7 @@ public void write(int ordinal, long value) { public void write(int ordinal, float value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(buffer(), offset, 0L); + writeLong(offset, 0); writeFloat(offset, value); } @@ -179,7 +179,7 @@ public void write(int ordinal, double value) { public void write(int ordinal, Decimal input, int precision, int scale) { if (precision <= Decimal.MAX_LONG_DIGITS()) { // make sure Decimal object has the same scale as DecimalType - if (input.changePrecision(precision, scale)) { + if (input != null && input.changePrecision(precision, scale)) { write(ordinal, input.toUnscaledLong()); } else { setNullAt(ordinal); @@ -192,10 +192,10 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // Note that we may pass in null Decimal object to set null for it. if (input == null || !input.changePrecision(precision, scale)) { // zero-out the bytes - Platform.putLong(buffer(), cursor(), 0L); - Platform.putLong(buffer(), cursor() + 8, 0L); + Platform.putLong(getBuffer(), cursor(), 0L); + Platform.putLong(getBuffer(), cursor() + 8, 0L); - BitSetMethods.set(buffer(), startingOffset, ordinal); + BitSetMethods.set(getBuffer(), startingOffset, ordinal); // keep the offset for future update setOffsetAndSize(ordinal, 0); } else { @@ -207,7 +207,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, buffer(), cursor(), numBytes); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, bytes.length); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index d6d907af8694..de0eb6dbb76b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -43,8 +43,8 @@ public final BufferHolder getBufferHolder() { return holder; } - public final byte[] buffer() { - return holder.buffer(); + public final byte[] getBuffer() { + return holder.getBuffer(); } public final void reset() { @@ -84,7 +84,7 @@ protected void setOffsetAndSize(int ordinal, int currentCursor, int size) { protected final void zeroOutPaddingBytes(int numBytes) { if ((numBytes & 0x07) > 0) { - Platform.putLong(buffer(), cursor() + ((numBytes >> 3) << 3), 0L); + Platform.putLong(getBuffer(), cursor() + ((numBytes >> 3) << 3), 0L); } } @@ -112,7 +112,7 @@ public final void write(int ordinal, UTF8String input) { zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. - input.writeToMemory(buffer(), cursor()); + input.writeToMemory(getBuffer(), cursor()); setOffsetAndSize(ordinal, numBytes); @@ -134,7 +134,7 @@ public final void write(int ordinal, byte[] input, int offset, int numBytes) { // Write the bytes to the variable length portion. 
Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET + offset, buffer(), cursor(), numBytes); + input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, numBytes); @@ -147,8 +147,8 @@ public final void write(int ordinal, CalendarInterval input) { grow(16); // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(buffer(), cursor(), input.months); - Platform.putLong(buffer(), cursor() + 8, input.microseconds); + Platform.putLong(getBuffer(), cursor(), input.months); + Platform.putLong(getBuffer(), cursor() + 8, input.microseconds); setOffsetAndSize(ordinal, 16); @@ -157,36 +157,36 @@ public final void write(int ordinal, CalendarInterval input) { } protected final void writeBoolean(long offset, boolean value) { - Platform.putBoolean(buffer(), offset, value); + Platform.putBoolean(getBuffer(), offset, value); } protected final void writeByte(long offset, byte value) { - Platform.putByte(buffer(), offset, value); + Platform.putByte(getBuffer(), offset, value); } protected final void writeShort(long offset, short value) { - Platform.putShort(buffer(), offset, value); + Platform.putShort(getBuffer(), offset, value); } protected final void writeInt(long offset, int value) { - Platform.putInt(buffer(), offset, value); + Platform.putInt(getBuffer(), offset, value); } protected final void writeLong(long offset, long value) { - Platform.putLong(buffer(), offset, value); + Platform.putLong(getBuffer(), offset, value); } protected final void writeFloat(long offset, float value) { if (Float.isNaN(value)) { value = Float.NaN; } - Platform.putFloat(buffer(), offset, value); + Platform.putFloat(getBuffer(), offset, value); } protected final void writeDouble(long offset, double value) { if (Double.isNaN(value)) { value = Double.NaN; } - Platform.putDouble(buffer(), offset, value); + Platform.putDouble(getBuffer(), offset, value); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index cecbcd9fc73e..847fc9685a73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -231,7 +231,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { // Write the keys and write the numBytes of key array into the first 8 bytes. 
writeArray(keyArrayWriter, keyWriter, map.keyArray()) Platform.putLong( - valueArrayWriter.buffer, + valueArrayWriter.getBuffer, previousCursor, valueArrayWriter.cursor - previousCursor - 8 ) @@ -347,7 +347,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { Platform.copyMemory( baseObject, baseOffset, - writer.buffer, + writer.getBuffer, writer.cursor, sizeInBytes) writer.increaseCursor(sizeInBytes) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index ebab13055733..231d4d17f83a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -283,7 +283,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} // Write the numBytes of key array into the first 8 bytes. - Platform.putLong($rowWriter.buffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); + Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} } @@ -300,7 +300,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final int $sizeInBytes = $input.getSizeInBytes(); // grow the global buffer before writing data. $rowWriter.grow($sizeInBytes); - $input.writeToMemory($rowWriter.buffer(), $rowWriter.cursor()); + $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor()); $rowWriter.increaseCursor($sizeInBytes); """ } From 209da248279a12fe75dd3e58156c4bbb50576086 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 2 Apr 2018 10:08:00 +0100 Subject: [PATCH 11/11] address review comment --- .../kafka010/KafkaRecordToUnsafeRowConverter.scala | 1 - .../expressions/codegen/UnsafeRowWriter.java | 13 +++++-------- .../expressions/InterpretedUnsafeProjection.scala | 1 - .../codegen/GenerateUnsafeProjection.scala | 14 +------------- .../expressions/RowBasedKeyValueBatchSuite.java | 3 --- .../aggregate/RowBasedHashMapGenerator.scala | 1 - .../columnar/GenerateColumnAccessor.scala | 1 - .../datasources/text/TextFileFormat.scala | 1 - 8 files changed, 6 insertions(+), 29 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala index d99e7a7e57d6..f35a143e0037 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -44,7 +44,6 @@ private[kafka010] class KafkaRecordToUnsafeRowConverter { 5, DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) rowWriter.write(6, record.timestampType.id) - rowWriter.setTotalSize() rowWriter.getRow() } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 8d88cc95504f..71c49d8ed017 100644 --- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -34,10 +34,6 @@ * Note that if this is the outermost writer, which means we will always write from the very * beginning of the global row buffer, we don't need to update `startingOffset` and can just call * `zeroOutNullBytes` before writing new data. - * - * Generally we should call `UnsafeRowWriter.setTotalSize` to update the size of the result row, - * after writing a record to the buffer. However, we can skip this step if the fields of row are - * all fixed-length, as the size of result row is also fixed. */ public final class UnsafeRowWriter extends UnsafeWriter { @@ -74,12 +70,13 @@ private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) { this.startingOffset = cursor(); } + /** + * Updates the total size of the UnsafeRow using the size collected by BufferHolder, and returns + * the UnsafeRow created in the constructor. + */ public UnsafeRow getRow() { - return row; - } - - public void setTotalSize() { row.setTotalSize(totalSize()); + return row; } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 847fc9685a73..b31466f5c92d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -80,7 +80,6 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe // Write the intermediate row to an unsafe row. rowWriter.reset() writer(intermediate) - rowWriter.setTotalSize() rowWriter.getRow() } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 231d4d17f83a..ab2254cd9f70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -322,17 +322,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") - val resetBufferHolder = if (numVarLenFields == 0) { - "" - } else { - s"$rowWriter.reset();" - } - val updateRowSize = if (numVarLenFields == 0) { - "" - } else { - s"$rowWriter.setTotalSize();" - } - // Evaluate all the subexpression.
val evalSubexpr = ctx.subexprFunctions.mkString("\n") @@ -341,10 +330,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val code = s""" - $resetBufferHolder + $rowWriter.reset(); $evalSubexpr $writeExpressions - $updateRowSize """ ExprCode(code, "false", s"$rowWriter.getRow()") } diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index 66866880d41b..2da87113c622 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -58,7 +58,6 @@ private UnsafeRow makeKeyRow(long k1, String k2) { writer.reset(); writer.write(0, k1); writer.write(1, UTF8String.fromString(k2)); - writer.setTotalSize(); return writer.getRow(); } @@ -67,7 +66,6 @@ private UnsafeRow makeKeyRow(long k1, long k2) { writer.reset(); writer.write(0, k1); writer.write(1, k2); - writer.setTotalSize(); return writer.getRow(); } @@ -76,7 +74,6 @@ private UnsafeRow makeValueRow(long v1, long v2) { writer.reset(); writer.write(0, v1); writer.write(1, v2); - writer.setTotalSize(); return writer.getRow(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 3990f7bd7688..d5508275c48c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -171,7 +171,6 @@ class RowBasedHashMapGenerator( | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed | agg_rowWriter.zeroOutNullBytes(); | ${createUnsafeRowForKey}; - | agg_rowWriter.setTotalSize(); | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result | = agg_rowWriter.getRow(); | Object kbase = agg_result.getBaseObject(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 6bbe19e702b6..2d699e8a9d08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -213,7 +213,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera rowWriter.reset(); rowWriter.zeroOutNullBytes(); ${extractorCalls} - rowWriter.setTotalSize(); return rowWriter.getRow(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index f9d833eb5d32..d1ccb92ef8c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -139,7 +139,6 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { // Writes to an UnsafeRow directly unsafeRowWriter.reset() unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRowWriter.setTotalSize() unsafeRowWriter.getRow() } }
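
A minimal usage sketch of the consolidated calling convention after this series (illustrative only, not part of the patch; it assumes the two-argument (numFields, initialBufferSize) constructor that GenerateUnsafeProjection uses above, and the field values are placeholders). getRow() now folds in the old setTotalSize() step, so callers simply reset, write, and fetch the row:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
    import org.apache.spark.unsafe.types.UTF8String

    // 2 fields; reserve 32 bytes up front for variable-length data
    val rowWriter = new UnsafeRowWriter(2, 32)

    rowWriter.reset()            // start a new record in the shared buffer
    rowWriter.zeroOutNullBytes() // clear null bits before writing new data
    rowWriter.write(0, 42L)                          // fixed-length field
    rowWriter.write(1, UTF8String.fromString("abc")) // variable-length field
    val row: UnsafeRow = rowWriter.getRow()          // sets the total size and returns the row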
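For nested variable-length values, the writers follow the previousCursor protocol that patch 10 hoists into UnsafeWriter: remember the cursor, write the nested bytes through a child writer sharing the same buffer, then record the (offset, size) pair in the parent's fixed-length slot. A hand-written sketch, where ordinal and writeNestedValue() are hypothetical placeholders:

    val previousCursor = rowWriter.cursor()  // where the nested value will start
    writeNestedValue()                       // hypothetical: emits the struct/array/map bytes
    // equivalent to setOffsetAndSize(ordinal, previousCursor, cursor() - previousCursor)
    rowWriter.setOffsetAndSizeFromPreviousCursor(ordinal, previousCursor)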
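Maps are written as a key array followed by a value array, with an 8-byte slot reserved in front of the key array and back-filled with the key array's size once it is known; both the interpreted and generated paths above do this. A sketch with hypothetical writeKeyArray()/writeValueArray() helpers:

    import org.apache.spark.unsafe.Platform

    rowWriter.grow(8)           // reserve 8 bytes for the key array's numBytes
    rowWriter.increaseCursor(8)
    val tmpCursor = rowWriter.cursor()
    writeKeyArray()             // hypothetical: writes the key UnsafeArrayData bytes
    // back-fill the reserved slot with the size of the key array
    Platform.putLong(rowWriter.getBuffer(), tmpCursor - 8, rowWriter.cursor() - tmpCursor)
    writeValueArray()           // hypothetical: writes the value UnsafeArrayData bytes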
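Finally, note that writeFloat and writeDouble canonicalize NaN before storing, so every NaN lands in the buffer with a single bit pattern and byte-wise comparisons of rows treat all NaNs as equal. A rough Scala equivalent of that guard:

    // Sketch: canonicalize NaN so equal values share identical bytes in the row buffer
    def normalizeFloat(value: Float): Float =
      if (value.isNaN) Float.NaN else value

    def normalizeDouble(value: Double): Double =
      if (value.isNaN) Double.NaN else value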