diff --git a/core/src/main/java/org/apache/spark/memory/MemoryMode.java b/core/src/main/java/org/apache/spark/memory/MemoryMode.java index 3a5e72d8aaec..ca26a77ee9f6 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryMode.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryMode.java @@ -22,5 +22,6 @@ @Private public enum MemoryMode { ON_HEAP, - OFF_HEAP + OFF_HEAP, + ON_HEAP_UNSAFE } diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index 7def44bd2a2b..7d900938204c 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -69,12 +69,17 @@ private[spark] class Benchmark( * @param name of the benchmark case * @param numIters if non-zero, forces exactly this many iterations to be run */ - def addCase(name: String, numIters: Int = 0)(f: Int => Unit): Unit = { - addTimerCase(name, numIters) { timer => + def addCase( + name: String, + numIters: Int = 0, + prepare: () => Unit = () => { }, + cleanup: () => Unit = () => { })(f: Int => Unit): Unit = { + val timedF = (timer: Benchmark.Timer) => { timer.startTiming() f(timer.iteration) timer.stopTiming() } + benchmarks += Benchmark.Case(name, timedF, numIters, prepare, cleanup) } /** @@ -101,7 +106,7 @@ private[spark] class Benchmark( val results = benchmarks.map { c => println(" Running case: " + c.name) - measure(valuesPerIteration, c.numIters)(c.fn) + measure(valuesPerIteration, c.numIters, c.prepare, c.cleanup)(c.fn) } println @@ -128,21 +133,33 @@ private[spark] class Benchmark( * Runs a single function `f` for iters, returning the average time the function took and * the rate of the function. */ - def measure(num: Long, overrideNumIters: Int)(f: Timer => Unit): Result = { + def measure(num: Long, overrideNumIters: Int, prepare: () => Unit, cleanup: () => Unit) + (f: Timer => Unit): Result = { System.gc() // ensures garbage from previous cases don't impact this one val warmupDeadline = warmupTime.fromNow while (!warmupDeadline.isOverdue) { - f(new Benchmark.Timer(-1)) + try { + prepare() + f(new Benchmark.Timer(-1)) + } finally { + cleanup() + } } val minIters = if (overrideNumIters != 0) overrideNumIters else minNumIters val minDuration = if (overrideNumIters != 0) 0 else minTime.toNanos val runTimes = ArrayBuffer[Long]() var i = 0 while (i < minIters || runTimes.sum < minDuration) { - val timer = new Benchmark.Timer(i) - f(timer) - val runTime = timer.totalTime() - runTimes += runTime + val runTime = try { + prepare() + val timer = new Benchmark.Timer(i) + f(timer) + val time = timer.totalTime() + runTimes += time + time + } finally { + cleanup() + } if (outputPerIteration) { // scalastyle:off @@ -188,7 +205,12 @@ private[spark] object Benchmark { } } - case class Case(name: String, fn: Timer => Unit, numIters: Int) + case class Case( + name: String, + fn: Timer => Unit, + numIters: Int, + prepare: () => Unit = () => { }, + cleanup: () => Unit = () => { }) case class Result(avgMs: Double, bestRate: Double, bestMs: Double) /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 354c878aca00..9c0082e15635 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,6 +16,7 @@ */ package 
org.apache.spark.sql.execution.vectorized; +import java.io.Serializable; import java.math.BigDecimal; import java.math.BigInteger; @@ -25,10 +26,14 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeMapData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -57,7 +62,9 @@ * * ColumnVectors are intended to be reused. */ -public abstract class ColumnVector implements AutoCloseable { +public abstract class ColumnVector implements AutoCloseable, Serializable { + ColumnVector() { } + /** * Allocates a column to store elements of `type` on or off heap. * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is @@ -66,6 +73,8 @@ public abstract class ColumnVector implements AutoCloseable { public static ColumnVector allocate(int capacity, DataType type, MemoryMode mode) { if (mode == MemoryMode.OFF_HEAP) { return new OffHeapColumnVector(capacity, type); + } else if (mode == MemoryMode.ON_HEAP_UNSAFE) { + return new OnHeapUnsafeColumnVector(capacity, type); } else { return new OnHeapColumnVector(capacity, type); } @@ -548,18 +557,69 @@ public ColumnarBatch.Row getStruct(int rowId) { * Returns a utility object to get structs. * provided to keep API compatibility with InternalRow for code generation */ - public ColumnarBatch.Row getStruct(int rowId, int size) { - resultStruct.rowId = rowId; - return resultStruct; + public InternalRow getStruct(int rowId, int size) { + if (!unsafeDirectCopy) { + resultStruct.rowId = rowId; + return resultStruct; + } + resultArray.data.loadBytes(resultArray); + int offset = getArrayOffset(rowId); + int length = getArrayLength(rowId); + UnsafeRow map = new UnsafeRow(size); + map.pointTo(resultArray.byteArray, Platform.BYTE_ARRAY_OFFSET + offset, length); + return map; + } + + public int putStruct(int rowId, InternalRow row) { + if (!unsafeDirectCopy) { + throw new UnsupportedOperationException(); + } + assert(row instanceof UnsafeRow); + UnsafeRow unsafeRow = (UnsafeRow)row; + byte[] value = (byte[])unsafeRow.getBaseObject(); + long offset = unsafeRow.getBaseOffset() - Platform.BYTE_ARRAY_OFFSET; + int length = unsafeRow.getSizeInBytes(); + if (offset > Integer.MAX_VALUE) { + throw new UnsupportedOperationException("Cannot put this map to ColumnVector as " + + "it's too big."); + } + putByteArray(rowId, value, (int)offset, length); + return length; } /** * Returns the array at rowid. 
*/ - public final Array getArray(int rowId) { - resultArray.length = getArrayLength(rowId); - resultArray.offset = getArrayOffset(rowId); - return resultArray; + public final ArrayData getArray(int rowId) { + if (unsafeDirectCopy) { + resultArray.data.loadBytes(resultArray); // update resultArray.byteData + int offset = getArrayOffset(rowId); + int length = getArrayLength(rowId); + UnsafeArrayData array = new UnsafeArrayData(); + array.pointTo(resultArray.byteArray, Platform.BYTE_ARRAY_OFFSET + offset, length); + return array; + } else { + resultArray.length = getArrayLength(rowId); + resultArray.offset = getArrayOffset(rowId); + return resultArray; + } + } + + public final int putArray(int rowId, ArrayData array) { + if (!unsafeDirectCopy) { + throw new UnsupportedOperationException(); + } + assert(array instanceof UnsafeArrayData); + UnsafeArrayData unsafeArray = (UnsafeArrayData)array; + byte[] value = (byte[])unsafeArray.getBaseObject(); + long offset = unsafeArray.getBaseOffset() - Platform.BYTE_ARRAY_OFFSET; + int length = unsafeArray.getSizeInBytes(); + if (offset > Integer.MAX_VALUE) { + throw new UnsupportedOperationException("Cannot put this array to ColumnVector as " + + "it's too big."); + } + putByteArray(rowId, value, (int)offset, length); + return length; } /** @@ -579,7 +639,9 @@ public final int putByteArray(int rowId, byte[] value) { * Returns the value for rowId. */ private Array getByteArray(int rowId) { - Array array = getArray(rowId); + resultArray.length = getArrayLength(rowId); + resultArray.offset = getArrayOffset(rowId); + Array array = resultArray; array.data.loadBytes(array); return array; } @@ -587,8 +649,33 @@ private Array getByteArray(int rowId) { /** * Returns the value for rowId. */ - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public MapData getMap(int rowId) { + if (!unsafeDirectCopy) { + throw new UnsupportedOperationException(); + } + resultArray.data.loadBytes(resultArray); + int offset = getArrayOffset(rowId); + int length = getArrayLength(rowId); + UnsafeMapData map = new UnsafeMapData(); + map.pointTo(resultArray.byteArray, Platform.BYTE_ARRAY_OFFSET + offset, length); + return map; + } + + public int putMap(int rowId, MapData map) { + if (!unsafeDirectCopy) { + throw new UnsupportedOperationException(); + } + assert(map instanceof UnsafeMapData); + UnsafeMapData unsafeMap = (UnsafeMapData)map; + byte[] value = (byte[])unsafeMap.getBaseObject(); + long offset = unsafeMap.getBaseOffset() - Platform.BYTE_ARRAY_OFFSET; + int length = unsafeMap.getSizeInBytes(); + if (offset > Integer.MAX_VALUE) { + throw new UnsupportedOperationException("Cannot put this map to ColumnVector as " + + "it's too big."); + } + putByteArray(rowId, value, (int)offset, length); + return length; } /** @@ -609,14 +696,18 @@ public final Decimal getDecimal(int rowId, int precision, int scale) { } - public final void putDecimal(int rowId, Decimal value, int precision) { + public final int putDecimal(int rowId, Decimal value, int precision) { if (precision <= Decimal.MAX_INT_DIGITS()) { putInt(rowId, (int) value.toUnscaledLong()); + return 4; } else if (precision <= Decimal.MAX_LONG_DIGITS()) { putLong(rowId, value.toUnscaledLong()); + return 8; } else { BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue(); - putByteArray(rowId, bigInteger.toByteArray()); + byte[] array = bigInteger.toByteArray(); + putByteArray(rowId, array); + return array.length; } } @@ -633,6 +724,13 @@ public final UTF8String getUTF8String(int rowId) { } } 
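For reference, a minimal sketch (not part of this patch) of how the direct-copy struct APIs added above appear intended to be used when the vector is allocated in `ON_HEAP_UNSAFE` mode. The `UnsafeProjection` conversion and the local names are illustrative only.

```scala
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.execution.vectorized.ColumnVector
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

val structType = new StructType().add("i", IntegerType).add("s", StringType)
// ON_HEAP_UNSAFE enables unsafeDirectCopy, so putStruct/getStruct copy raw UnsafeRow bytes.
val col = ColumnVector.allocate(16, structType, MemoryMode.ON_HEAP_UNSAFE)

// putStruct asserts its argument is an UnsafeRow, so convert through an UnsafeProjection first.
val toUnsafe = UnsafeProjection.create(structType)
val written = col.putStruct(0, toUnsafe(InternalRow(42, UTF8String.fromString("abc"))))
// `written` is the number of bytes copied into the child byte column.

val back = col.getStruct(0, structType.length) // an UnsafeRow pointing into the column's byte[]
assert(back.getInt(0) == 42)
assert(back.getUTF8String(1) == UTF8String.fromString("abc"))
```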
+ public final int putUTF8String(int rowId, UTF8String string) { + assert(dictionary == null); + byte[] array = string.getBytes(); + putByteArray(rowId, array); + return array.length; + } + /** * Returns the byte array for rowId. */ @@ -648,6 +746,11 @@ public final byte[] getBinary(int rowId) { } } + public final int putBinary(int rowId, byte[] bytes) { + putByteArray(rowId, bytes); + return bytes.length; + } + /** * Append APIs. These APIs all behave similarly and will append data to the current vector. It * is not valid to mix the put and append APIs. The append APIs are slower and should only be @@ -894,10 +997,12 @@ public final int appendStruct(boolean isNull) { @VisibleForTesting protected int MAX_CAPACITY = Integer.MAX_VALUE; + protected boolean unsafeDirectCopy; + /** * Data type for this column. */ - protected final DataType type; + protected DataType type; /** * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. @@ -929,17 +1034,17 @@ public final int appendStruct(boolean isNull) { /** * If this is a nested type (array or struct), the column for the child data. */ - protected final ColumnVector[] childColumns; + protected ColumnVector[] childColumns; /** * Reusable Array holder for getArray(). */ - protected final Array resultArray; + protected Array resultArray; /** * Reusable Struct holder for getStruct(). */ - protected final ColumnarBatch.Row resultStruct; + protected ColumnarBatch.Row resultStruct; /** * The Dictionary for this column. @@ -991,14 +1096,20 @@ public ColumnVector getDictionaryIds() { * type. */ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { + this(capacity, type, memMode, false); + } + + protected ColumnVector(int capacity, DataType type, MemoryMode memMode, boolean unsafeDirectCopy) { this.capacity = capacity; this.type = type; + this.unsafeDirectCopy = unsafeDirectCopy; if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType - || DecimalType.isByteArrayDecimalType(type)) { + || DecimalType.isByteArrayDecimalType(type) + || unsafeDirectCopy && (type instanceof MapType || type instanceof StructType)) { DataType childType; int childCapacity = capacity; - if (type instanceof ArrayType) { + if (!unsafeDirectCopy && type instanceof ArrayType) { childType = ((ArrayType)type).elementType(); } else { childType = DataTypes.ByteType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index a6ce4c2edc23..d9ec7145622a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.io.Serializable; import java.math.BigDecimal; import java.util.*; @@ -43,7 +44,7 @@ * - There are many TODOs for the existing APIs. They should throw a not implemented exception. * - Compaction: The batch and columns should be able to compact based on a selection vector. 
*/ -public final class ColumnarBatch { +public final class ColumnarBatch implements Serializable { private static final int DEFAULT_BATCH_SIZE = 4 * 1024; private static MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapUnsafeColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapUnsafeColumnVector.java new file mode 100644 index 000000000000..5c9f40e0cfd5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapUnsafeColumnVector.java @@ -0,0 +1,478 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.io.*; + +import org.apache.commons.io.IOUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.io.CompressionCodec; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; + +/** + * A column backed by an in memory JVM array. But, all of data types are stored into byte[]. + * This stores the NULLs as a byte per value and a java array for the values. + */ +public final class OnHeapUnsafeColumnVector extends ColumnVector implements Serializable { + + // The data stored in these arrays need to maintain binary compatible. We can + // directly pass this buffer to external components. + + // This is faster than a boolean array and we optimize this over memory footprint. + private byte[] nulls; + private byte[] compressedNulls; + + // Array for all types + private byte[] data; + private byte[] compressedData; + + // Only set if type is Array. 
+ private int[] arrayLengths; + private int[] arrayOffsets; + + private boolean compressed; + private transient CompressionCodec codec = null; + + OnHeapUnsafeColumnVector() { } + + protected OnHeapUnsafeColumnVector(int capacity, DataType type) { + super(capacity, type, MemoryMode.ON_HEAP, true); + reserveInternal(capacity); + reset(); + } + + @Override + public long valuesNativeAddress() { + throw new RuntimeException("Cannot get native address for on heap column"); + } + @Override + public long nullsNativeAddress() { + throw new RuntimeException("Cannot get native address for on heap column"); + } + + @Override + public void close() { + } + + public void compress(SparkConf conf) { + if (compressed) return; + if (codec == null) { + String codecName = conf.get(SQLConf.CACHE_COMPRESSION_CODEC()); + codec = CompressionCodec$.MODULE$.createCodec(conf, codecName); + } + ByteArrayOutputStream bos; + OutputStream out; + + if (data != null) { + bos = new ByteArrayOutputStream(); + out = codec.compressedOutputStream(bos); + try { + try { + out.write(data); + } finally { + out.close(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + if (bos.size() < data.length) { + compressedData = bos.toByteArray(); + data = null; + } + } + + if (nulls != null) { + bos = new ByteArrayOutputStream(); + out = codec.compressedOutputStream(bos); + try { + try { + out.write(nulls); + } finally { + out.close(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + if (bos.size() < nulls.length) { + compressedNulls = bos.toByteArray(); + nulls = null; + } + } + compressed = (compressedData != null) || (compressedNulls != null); + } + + public void decompress(SparkConf conf) { + if (!compressed) return; + if (codec == null) { + String codecName = conf.get(SQLConf.CACHE_COMPRESSION_CODEC()); + codec = CompressionCodec$.MODULE$.createCodec(conf, codecName); + } + ByteArrayInputStream bis; + InputStream in; + + if (compressedData != null) { + bis = new ByteArrayInputStream(compressedData); + in = codec.compressedInputStream(bis); + try { + try { + data = IOUtils.toByteArray(in); + } finally { + in.close(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + compressedData = null; + } + + if (compressedNulls != null) { + bis = new ByteArrayInputStream(compressedNulls); + in = codec.compressedInputStream(bis); + try { + try { + nulls = IOUtils.toByteArray(in); + } finally { + in.close(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + compressedNulls = null; + } + compressed = false; + } + + // + // APIs dealing with nulls + // + + @Override + public void putNotNull(int rowId) { + Platform.putByte(nulls, Platform.BYTE_ARRAY_OFFSET + rowId, (byte)0); + } + + @Override + public void putNull(int rowId) { + Platform.putByte(nulls, Platform.BYTE_ARRAY_OFFSET + rowId, (byte)1); + ++numNulls; + anyNullsSet = true; + } + + @Override + public void putNulls(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + @Override + public void putNotNulls(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isNullAt(int rowId) { + return Platform.getByte(nulls, Platform.BYTE_ARRAY_OFFSET + rowId) == 1; + } + + // + // APIs dealing with Booleans + // + + @Override + public void putBoolean(int rowId, boolean value) { + Platform.putBoolean(data, Platform.BYTE_ARRAY_OFFSET + rowId, value); + } + + @Override + public void putBooleans(int rowId, int count, boolean value) { + throw new 
UnsupportedOperationException(); + } + + @Override + public boolean getBoolean(int rowId) { + return Platform.getBoolean(data, Platform.BYTE_ARRAY_OFFSET + rowId); + } + + // + + // + // APIs dealing with Bytes + // + + @Override + public void putByte(int rowId, byte value) { + Platform.putByte(data, Platform.BYTE_ARRAY_OFFSET + rowId, value); + } + + @Override + public void putBytes(int rowId, int count, byte value) { + for (int i = 0; i < count; ++i) { + Platform.putByte(data, Platform.BYTE_ARRAY_OFFSET + rowId + i, value); + } + } + + @Override + public void putBytes(int rowId, int count, byte[] src, int srcIndex) { + System.arraycopy(src, srcIndex, data, rowId, count); + } + + @Override + public byte getByte(int rowId) { + return Platform.getByte(data, Platform.BYTE_ARRAY_OFFSET + rowId); + } + + // + // APIs dealing with Shorts + // + + @Override + public void putShort(int rowId, short value) { + Platform.putShort(data, Platform.BYTE_ARRAY_OFFSET + rowId * 2, value); + } + + @Override + public void putShorts(int rowId, int count, short value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putShorts(int rowId, int count, short[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + return Platform.getShort(data, Platform.BYTE_ARRAY_OFFSET + rowId * 2); + } + + + // + // APIs dealing with Ints + // + + @Override + public void putInt(int rowId, int value) { + Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + rowId * 4, value); + } + + @Override + public void putInts(int rowId, int count, int value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putInts(int rowId, int count, int[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + return Platform.getInt(data, Platform.BYTE_ARRAY_OFFSET + rowId * 4); + } + + /** + * Returns the dictionary Id for rowId. + * This should only be called when the ColumnVector is dictionaryIds. + * We have this separate method for dictionaryIds as per SPARK-16928. 
+ */ + public int getDictId(int rowId) { throw new UnsupportedOperationException(); } + + // + // APIs dealing with Longs + // + + @Override + public void putLong(int rowId, long value) { + Platform.putLong(data, Platform.BYTE_ARRAY_OFFSET + rowId * 8, value); + } + + @Override + public void putLongs(int rowId, int count, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putLongs(int rowId, int count, long[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + return Platform.getLong(data, Platform.BYTE_ARRAY_OFFSET + rowId * 8); + } + + // + // APIs dealing with floats + // + + @Override + public void putFloat(int rowId, float value) { + Platform.putFloat(data, Platform.BYTE_ARRAY_OFFSET + rowId * 4, value); + } + + @Override + public void putFloats(int rowId, int count, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putFloats(int rowId, int count, float[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public void putFloats(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + return Platform.getFloat(data, Platform.BYTE_ARRAY_OFFSET + rowId * 4); + } + + // + // APIs dealing with doubles + // + + @Override + public void putDouble(int rowId, double value) { + Platform.putDouble(data, Platform.BYTE_ARRAY_OFFSET + rowId * 8, value); + } + + @Override + public void putDoubles(int rowId, int count, double value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putDoubles(int rowId, int count, double[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + return Platform.getDouble(data, Platform.BYTE_ARRAY_OFFSET + rowId * 8); + } + + // + // APIs dealing with Arrays + // + + @Override + public int getArrayLength(int rowId) { + return arrayLengths[rowId]; + } + @Override + public int getArrayOffset(int rowId) { + return arrayOffsets[rowId]; + } + + @Override + public void putArray(int rowId, int offset, int length) { + arrayOffsets[rowId] = offset; + arrayLengths[rowId] = length; + } + + @Override + public void loadBytes(ColumnVector.Array array) { + array.byteArray = data; + array.byteArrayOffset = array.offset; + } + + // + // APIs dealing with Byte Arrays + // + + @Override + public int putByteArray(int rowId, byte[] value, int offset, int length) { + int result = arrayData().appendBytes(length, value, offset); + arrayOffsets[rowId] = result; + arrayLengths[rowId] = length; + return result; + } + + // Spilt this function out since it is the slow path. 
+ @Override + protected void reserveInternal(int newCapacity) { + int factor = 0; + if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) { + int[] newLengths = new int[newCapacity]; + int[] newOffsets = new int[newCapacity]; + if (this.arrayLengths != null) { + System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended); + } + arrayLengths = newLengths; + arrayOffsets = newOffsets; + factor = -1; + } else if (resultStruct != null || type instanceof NullType) { + // Nothing to store. + factor = -1; + } else if (type instanceof BooleanType) { + factor = 1; + } else if (type instanceof ByteType) { + factor = 1; + } else if (type instanceof ShortType) { + factor = 2; + } else if (type instanceof IntegerType || type instanceof DateType || + DecimalType.is32BitDecimalType(type)) { + factor = 4; + } else if (type instanceof LongType || type instanceof TimestampType || + DecimalType.is64BitDecimalType(type)) { + factor = 8; + } else if (type instanceof FloatType) { + factor = 4; + } else if (type instanceof DoubleType) { + factor = 8; + } + if (factor > 0) { + if (data == null || capacity < newCapacity) { + byte[] newData = new byte[newCapacity * factor]; + if (data != null) + Platform.copyMemory(data, Platform.BYTE_ARRAY_OFFSET, + newData, Platform.BYTE_ARRAY_OFFSET, elementsAppended * factor); + data = newData; + } + } else if (factor == 0) { + throw new RuntimeException("Unhandled " + type); + } + + byte[] newNulls = new byte[newCapacity]; + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended); + nulls = newNulls; + + capacity = newCapacity; + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 04fba17be4bf..05c861e353da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -33,10 +33,15 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val inMemoryTableScan: InMemoryTableScanExec = null + val columnIndexes: Array[Int] = null + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + lazy val enableScanStatistics: Boolean = + sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean + /** * Generate [[ColumnVector]] expressions for our parent to consume as rows. * This is called once per [[ColumnarBatch]]. 
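For context, a rough sketch (not in the patch) of the compress()/decompress() round trip the new vector goes through when a serialized storage level is used. It assumes the cache compression codec config referenced by `SQLConf.CACHE_COMPRESSION_CODEC` (defined outside this excerpt) has a usable default on a plain `SparkConf`.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.execution.vectorized.{ColumnVector, OnHeapUnsafeColumnVector}
import org.apache.spark.sql.types.LongType

val conf = new SparkConf()
val col = ColumnVector.allocate(1024, LongType, MemoryMode.ON_HEAP_UNSAFE)
  .asInstanceOf[OnHeapUnsafeColumnVector]
(0 until 1024).foreach(i => col.putLong(i, i.toLong))

// compress() only swaps in the compressed buffer when it is actually smaller
// than the raw data/nulls arrays; otherwise the vector is left untouched.
col.compress(conf)
// decompress() is a no-op unless compress() replaced a buffer.
col.decompress(conf)
assert(col.getLong(7) == 7L)
```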
@@ -78,6 +83,17 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val scanTimeMetric = metricTerm(ctx, "scanTime") val scanTimeTotalNs = ctx.freshName("scanTime") ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;") + val incReadBatches = if (!enableScanStatistics) "" else { + val readPartitions = ctx.addReferenceObj("readPartitions", inMemoryTableScan.readPartitions) + val readBatches = ctx.addReferenceObj("readBatches", inMemoryTableScan.readBatches) + ctx.addMutableState("int", "initializeInMemoryTableScanStatistics", + s""" + |$readPartitions.setValue(0); + |$readBatches.setValue(0); + |if ($input.hasNext()) { $readPartitions.add(1); } + """.stripMargin) + s"$readBatches.add(1);" + } val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" val batch = ctx.freshName("batch") @@ -89,7 +105,8 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val colVars = output.indices.map(i => ctx.freshName("colInstance" + i)) val columnAssigns = colVars.zipWithIndex.map { case (name, i) => ctx.addMutableState(columnVectorClz, name, s"$name = null;") - s"$name = $batch.column($i);" + val index = if (columnIndexes == null) i else columnIndexes(i) + s"$name = $batch.column($index);" } val nextBatch = ctx.freshName("nextBatch") @@ -99,6 +116,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { | long getBatchStart = System.nanoTime(); | if ($input.hasNext()) { | $batch = ($columnarBatchClz)$input.next(); + | $incReadBatches | $numOutputRows.add($batch.numRows()); | $idx = 0; | ${columnAssigns.mkString("", "\n", "\n")} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 470307bd940a..3197d0c0358e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -62,11 +62,20 @@ private[columnar] sealed trait ColumnStats extends Serializable { count += 1 } + def gatherNullStats(): Unit = { + nullCount += 1 + // 1 bytes for null position + sizeInBytes += 1 + count += 1 + } + /** * Column statistics represented as a single row, currently including closed lower bound, closed * upper bound and null count. 
*/ def collectedStatistics: GenericInternalRow + + def collectedStats: Array[Any] } /** @@ -75,6 +84,8 @@ private[columnar] sealed trait ColumnStats extends Serializable { private[columnar] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) + override def collectedStats: Array[Any] = Array[Any](null, null, nullCount, count, 0L) + override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) } @@ -93,6 +104,15 @@ private[columnar] class BooleanColumnStats extends ColumnStats { } } + def gatherValueStats(value: Boolean): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += BOOLEAN.defaultSize + count += 1 + } + + override def collectedStats: Array[Any] = Array[Any](lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } @@ -111,6 +131,15 @@ private[columnar] class ByteColumnStats extends ColumnStats { } } + def gatherValueStats(value: Byte): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += BYTE.defaultSize + count += 1 + } + + override def collectedStats: Array[Any] = Array[Any](lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } @@ -129,6 +158,15 @@ private[columnar] class ShortColumnStats extends ColumnStats { } } + def gatherValueStats(value: Short): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += SHORT.defaultSize + count += 1 + } + + override def collectedStats: Array[Any] = Array[Any](lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } @@ -147,6 +185,15 @@ private[columnar] class IntColumnStats extends ColumnStats { } } + def gatherValueStats(value: Int): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += INT.defaultSize + count += 1 + } + + override def collectedStats: Array[Any] = Array[Any](lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } @@ -165,6 +212,15 @@ private[columnar] class LongColumnStats extends ColumnStats { } } + def gatherValueStats(value: Long): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += LONG.defaultSize + count += 1 + } + + override def collectedStats: Array[Any] = Array[Any](lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } @@ -183,6 +239,15 @@ private[columnar] class FloatColumnStats extends ColumnStats { } } + def gatherValueStats(value: Float): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += FLOAT.defaultSize + count += 1 + } + + override def collectedStats: Array[Any] = Array[Any](lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } @@ -201,6 +266,15 @@ 
private[columnar] class DoubleColumnStats extends ColumnStats { } } + def gatherValueStats(value: Double): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += DOUBLE.defaultSize + count += 1 + } + + override def collectedStats: Array[Any] = Array[Any](lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } @@ -219,6 +293,15 @@ private[columnar] class StringColumnStats extends ColumnStats { } } + def gatherValueStats(value: UTF8String, size: Int): Unit = { + if (upper == null || value.compareTo(upper) > 0) upper = value.clone() + if (lower == null || value.compareTo(lower) < 0) lower = value.clone() + sizeInBytes += (size + 4) + count += 1 + } + + override def collectedStats: Array[Any] = Array[Any](lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } @@ -231,6 +314,13 @@ private[columnar] class BinaryColumnStats extends ColumnStats { } } + def gatherValueStats(value: Array[Byte], size: Int): Unit = { + sizeInBytes += (size + 4) + count += 1 + } + + override def collectedStats: Array[Any] = Array[Any](null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } @@ -252,6 +342,15 @@ private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends C } } + def gatherValueStats(value: Decimal, size: Int): Unit = { + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + sizeInBytes += size + count += 1 + } + + override def collectedStats: Array[Any] = Array[Any](lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } @@ -266,6 +365,25 @@ private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStat } } + override def collectedStats: Array[Any] = null + override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } + +private[columnar] class OtherColumnStats() extends ColumnStats { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + throw new UnsupportedOperationException() + } + + def gatherValueStats(value: Object, size: Int): Unit = { + sizeInBytes += size + count += 1 + } + + override def collectedStats: Array[Any] = Array[Any](null, null, nullCount, count, sizeInBytes) + + override def collectedStatistics: GenericInternalRow = { + throw new UnsupportedOperationException() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 14024d6c1055..bbb99223df43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.execution.columnar +import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import 
org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.execution.vectorized.ColumnarBatch import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.types.UTF8String /** * An Iterator to walk through the InternalRows from a CachedBatch @@ -57,17 +62,56 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends BaseGenericInternalR override protected def genericGet(ordinal: Int): Any = throw new UnsupportedOperationException override def numFields: Int = throw new UnsupportedOperationException override def copy(): InternalRow = throw new UnsupportedOperationException + + def setDecimal(i: Int, v: Decimal, precision: Int, scale: Int): Unit = + writer.write(i, v, precision, scale) + def setUTF8String(i: Int, s: UTF8String): Unit = writer.write(i, s) + def setBinary(i: Int, b: Array[Byte]): Unit = writer.write(i, b) + def setArray(i: Int, a: ArrayData): Unit = { + val u = a.asInstanceOf[UnsafeArrayData] + val base = u.getBaseObject.asInstanceOf[Array[Byte]] + val offset = u.getBaseOffset - Platform.BYTE_ARRAY_OFFSET + if (offset > Integer.MAX_VALUE) { + throw new UnsupportedOperationException("Cannot write this array as it's too big.") + } + val size = u.getSizeInBytes + writer.write(i, base, offset.toInt, size) + } + def setMap(i: Int, m: MapData): Unit = { + val u = m.asInstanceOf[UnsafeMapData] + val base = u.getBaseObject.asInstanceOf[Array[Byte]] + val offset = u.getBaseOffset - Platform.BYTE_ARRAY_OFFSET + if (offset > Integer.MAX_VALUE) { + throw new UnsupportedOperationException("Cannot write this array as it's too big.") + } + val size = u.getSizeInBytes + writer.write(i, base, offset.toInt, size) + } + def setStruct(i: Int, r: InternalRow): Unit = { + val u = r.asInstanceOf[UnsafeRow] + val base = u.getBaseObject.asInstanceOf[Array[Byte]] + val offset = u.getBaseOffset - Platform.BYTE_ARRAY_OFFSET + if (offset > Integer.MAX_VALUE) { + throw new UnsupportedOperationException("Cannot write this array as it's too big.") + } + val size = u.getSizeInBytes + writer.write(i, base, offset.toInt, size) + } } /** * Generates bytecode for a [[ColumnarIterator]] for columnar cache. 
*/ -object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging { +class GenerateColumnAccessor(conf: SparkConf) + extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging { protected def canonicalize(in: Seq[DataType]): Seq[DataType] = in protected def bind(in: Seq[DataType], inputSchema: Seq[Attribute]): Seq[DataType] = in protected def create(columnTypes: Seq[DataType]): ColumnarIterator = { + if (conf != null) { + return createItrForCacheColumnarBatch(conf, columnTypes) + } val ctx = newCodeGenContext() val numFields = columnTypes.size val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) => @@ -159,6 +203,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; + import org.apache.spark.sql.execution.columnar.CachedBatchBytes; import org.apache.spark.sql.execution.columnar.MutableUnsafeRow; public SpecificColumnarIterator generate(Object[] references) { @@ -205,7 +250,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera return false; } - ${classOf[CachedBatch].getName} batch = (${classOf[CachedBatch].getName}) input.next(); + CachedBatchBytes batch = (CachedBatchBytes)input.next(); currentRow = 0; numRowsInBatch = batch.numRows(); for (int i = 0; i < columnIndexes.length; i ++) { @@ -228,8 +273,149 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val code = CodeFormatter.stripOverlappingComments( new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) - logDebug(s"Generated ColumnarIterator:\n${CodeFormatter.format(code)}") + logDebug(s"Generated ColumnarIteratorForCachedBatchBytes:\n${CodeFormatter.format(code)}") CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator] } + + protected def createItrForCacheColumnarBatch(conf: SparkConf, columnTypes: Seq[DataType]) + : ColumnarIterator = { + val ctx = newCodeGenContext() + val numFields = columnTypes.size + val confVar = ctx.addReferenceObj("conf", conf, classOf[SparkConf].getName) + + val setters = ctx.splitExpressions( + columnTypes.zipWithIndex.map { case (dt, index) => + val setter = dt match { + case NullType => + s"if (colInstances[$index].isNullAt(rowIdx)) { mutableRow.setNullAt($index); }\n" + case BooleanType => s"setBoolean($index, colInstances[$index].getBoolean(rowIdx))" + case ByteType => s"setByte($index, colInstances[$index].getByte(rowIdx))" + case ShortType => s"setShort($index, colInstances[$index].getShort(rowIdx))" + case IntegerType | DateType => s"setInt($index, colInstances[$index].getInt(rowIdx))" + case LongType | TimestampType => s"setLong($index, colInstances[$index].getLong(rowIdx))" + case FloatType => s"setFloat($index, colInstances[$index].getFloat(rowIdx))" + case DoubleType => s"setDouble($index, colInstances[$index].getDouble(rowIdx))" + case dt: DecimalType if dt.precision <= Decimal.MAX_INT_DIGITS => + s"setLong($index, (long)colInstances[$index].getInt(rowIdx))" + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + s"setLong($index, colInstances[$index].getLong(rowIdx))" + case dt: DecimalType => + val p = dt.precision + val s = dt.scale + s"setDecimal($index, colInstances[$index].getDecimal(rowIdx, $p, $s), $p, $s)" + case StringType => s"setUTF8String($index, 
colInstances[$index].getUTF8String(rowIdx))" + case BinaryType => s"setBinary($index, colInstances[$index].getBinary(rowIdx))" + case array: ArrayType => s"setArray($index, colInstances[$index].getArray(rowIdx))" + case t: MapType => s"setMap($index, colInstances[$index].getMap(rowIdx))" + case struct: StructType => + val s = struct.fields.length + s"setStruct($index, colInstances[$index].getStruct(rowIdx, $s))" + } + + dt match { + case NullType => setter + case dt: DecimalType if dt.precision > Decimal.MAX_LONG_DIGITS => + s""" + if (colInstances[$index].isNullAt(rowIdx)) { + mutableRow.setDecimal($index, null, ${dt.precision}, ${dt.scale}); + } else { + mutableRow.$setter; + } + """ + case _ => + s""" + if (colInstances[$index].isNullAt(rowIdx)) { + mutableRow.setNullAt($index); + } else { + mutableRow.$setter; + } + """ + } + }, + "apply", + Seq.empty + ) + + val codeBody = s""" + import scala.collection.Iterator; + import org.apache.spark.sql.types.DataType; + import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; + import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; + import org.apache.spark.sql.execution.columnar.MutableUnsafeRow; + import org.apache.spark.sql.execution.vectorized.ColumnVector; + import org.apache.spark.sql.execution.vectorized.OnHeapUnsafeColumnVector; + + public SpecificColumnarIterator generate(Object[] references) { + return new SpecificColumnarIterator(references); + } + + class SpecificColumnarIterator extends ${classOf[ColumnarIterator].getName} { + private ColumnVector[] colInstances; + private UnsafeRow unsafeRow = new UnsafeRow($numFields); + private BufferHolder bufferHolder = new BufferHolder(unsafeRow); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); + private MutableUnsafeRow mutableRow = null; + + private int rowIdx = 0; + private int numRowsInBatch = 0; + + private scala.collection.Iterator input = null; + private DataType[] columnTypes = null; + private int[] columnIndexes = null; + + ${ctx.declareMutableStates()} + + public SpecificColumnarIterator(Object[] references) { + ${ctx.initMutableStates()} + this.mutableRow = new MutableUnsafeRow(rowWriter); + } + + public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) { + this.input = input; + this.columnTypes = columnTypes; + this.columnIndexes = columnIndexes; + } + + ${ctx.declareAddedFunctions()} + + public boolean hasNext() { + if (rowIdx < numRowsInBatch) { + return true; + } + if (!input.hasNext()) { + return false; + } + + ${classOf[CachedColumnarBatch].getName} cachedBatch = + (${classOf[CachedColumnarBatch].getName}) input.next(); + ${classOf[ColumnarBatch].getName} batch = cachedBatch.columnarBatch(); + rowIdx = 0; + numRowsInBatch = cachedBatch.getNumRows(); + colInstances = new ColumnVector[columnIndexes.length]; + for (int i = 0; i < columnIndexes.length; i ++) { + colInstances[i] = batch.column(columnIndexes[i]); + ((OnHeapUnsafeColumnVector)colInstances[i]).decompress($confVar); + } + + return hasNext(); + } + + public InternalRow next() { + bufferHolder.reset(); + rowWriter.zeroOutNullBytes(); + ${setters} + unsafeRow.setTotalSize(bufferHolder.totalSize()); + rowIdx += 1; + return unsafeRow; + } + }""" + + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + logDebug(s"Generated ColumnarIteratorForCachedColumnarBatch:\n${CodeFormatter.format(code)}") + + 
CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[ColumnarIterator] + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnarBatch.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnarBatch.scala new file mode 100644 index 000000000000..e6fad0a9d80b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnarBatch.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, OnHeapUnsafeColumnVector} +import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.StorageLevel._ + + +/** + * A helper class to expose the scala iterator to Java. + */ +abstract class ColumnarBatchIterator extends Iterator[ColumnarBatch] + + +/** + * Generate code to batch [[InternalRow]]s into [[ColumnarBatch]]es. 
+ */ +class GenerateColumnarBatch( + schema: StructType, + batchSize: Int, + storageLevel: StorageLevel, + conf: SparkConf) + extends CodeGenerator[Iterator[InternalRow], Iterator[CachedColumnarBatch]] { + + protected def canonicalize(in: Iterator[InternalRow]): Iterator[InternalRow] = in + + protected def bind( + in: Iterator[InternalRow], + inputSchema: Seq[Attribute]): Iterator[InternalRow] = { + in + } + + protected def create(rowIterator: Iterator[InternalRow]): Iterator[CachedColumnarBatch] = { + import scala.collection.JavaConverters._ + val ctx = newCodeGenContext() + val columnStatsCls = classOf[ColumnStats].getName + val rowVar = ctx.freshName("row") + val batchVar = ctx.freshName("columnarBatch") + val rowNumVar = ctx.freshName("rowNum") + val numBytesVar = ctx.freshName("bytesInBatch") + ctx.addMutableState("long", numBytesVar, s"$numBytesVar = 0;") + val rowIterVar = ctx.addReferenceObj( + "rowIterator", rowIterator.asJava, classOf[java.util.Iterator[_]].getName) + val schemas = StructType( + schema.fields.map(s => StructField(s.name, + s.dataType match { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + }, s.nullable)) + ) + val schemaVar = ctx.addReferenceObj("schema", schemas, classOf[StructType].getName) + val maxNumBytes = ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE + val numColumns = schema.fields.length + + val colStatVars = (0 to numColumns - 1).map(i => ctx.freshName("colStat" + i)) + val colStatCode = ctx.splitExpressions( + (schemas.fields zip colStatVars).zipWithIndex.map { + case ((field, varName), i) => + val (columnStatsCls, arg) = field.dataType match { + case BooleanType => (classOf[BooleanColumnStats].getName, "()") + case ByteType => (classOf[ByteColumnStats].getName, "()") + case ShortType => (classOf[ShortColumnStats].getName, "()") + case IntegerType | DateType => (classOf[IntColumnStats].getName, "()") + case LongType | TimestampType => (classOf[LongColumnStats].getName, "()") + case FloatType => (classOf[FloatColumnStats].getName, "()") + case DoubleType => (classOf[DoubleColumnStats].getName, "()") + case StringType => (classOf[StringColumnStats].getName, "()") + case BinaryType => (classOf[BinaryColumnStats].getName, "()") + case dt: DecimalType => + (classOf[DecimalColumnStats].getName, s"(${dt.precision}, ${dt.scale})") + case dt => (classOf[OtherColumnStats].getName, "()") + } + ctx.addMutableState(columnStatsCls, varName, "") + s"$varName = new $columnStatsCls$arg; statsArray[$i] = $varName;\n" + }, + "apply", + Seq.empty + ) + + val populateColumnVectorsCode = ctx.splitExpressions( + (schemas.fields zip colStatVars).zipWithIndex.map { + case ((field, colStatVar), i) => + GenerateColumnarBatch.putColumnCode(ctx, field.dataType, field.nullable, + batchVar, rowVar, rowNumVar, colStatVar, i, numBytesVar).trim + "\n" + }, + "apply", + Seq(("InternalRow", rowVar), ("ColumnarBatch", batchVar), ("int", rowNumVar)) + ) + + val confVar = ctx.addReferenceObj("conf", conf, classOf[SparkConf].getName) + val compress = if (!GenerateColumnarBatch.isCompress(storageLevel)) "" else s""" + for (int i = 0; i < $numColumns; i++) { + ((OnHeapUnsafeColumnVector)$batchVar.column(i)).compress($confVar); + } + """ + + val code = s""" + import org.apache.spark.memory.MemoryMode; + import org.apache.spark.sql.catalyst.InternalRow; + import org.apache.spark.sql.execution.columnar.CachedColumnarBatch; + import org.apache.spark.sql.execution.columnar.GenerateColumnarBatch; + import org.apache.spark.sql.execution.vectorized.ColumnarBatch; + import 
org.apache.spark.sql.execution.vectorized.ColumnVector; + import org.apache.spark.sql.execution.vectorized.OnHeapUnsafeColumnVector; + + public GeneratedColumnarBatchIterator generate(Object[] references) { + return new GeneratedColumnarBatchIterator(references); + } + + class GeneratedColumnarBatchIterator extends ${classOf[ColumnarBatchIterator].getName} { + private Object[] references; + ${ctx.declareMutableStates()} + + public GeneratedColumnarBatchIterator(Object[] references) { + this.references = references; + ${ctx.initMutableStates()} + } + + ${ctx.declareAddedFunctions()} + + $columnStatsCls[] statsArray = new $columnStatsCls[$numColumns]; + private void allocateColumnStats() { + ${colStatCode.trim} + } + + @Override + public boolean hasNext() { + return $rowIterVar.hasNext(); + } + + @Override + public CachedColumnarBatch next() { + ColumnarBatch $batchVar = + ColumnarBatch.allocate($schemaVar, MemoryMode.ON_HEAP_UNSAFE, $batchSize); + allocateColumnStats(); + int $rowNumVar = 0; + $numBytesVar = 0; + while ($rowIterVar.hasNext() && $rowNumVar < $batchSize && $numBytesVar < $maxNumBytes) { + InternalRow $rowVar = (InternalRow) $rowIterVar.next(); + $populateColumnVectorsCode + $rowNumVar += 1; + } + $batchVar.setNumRows($rowNumVar); + ${compress.trim} + return CachedColumnarBatch.apply( + $batchVar, GenerateColumnarBatch.generateStats(statsArray)); + } + } + """ + val formattedCode = CodeFormatter.stripOverlappingComments( + new CodeAndComment(code, ctx.getPlaceHolderToComments())) + CodeGenerator.compile(formattedCode).generate(ctx.references.toArray) + .asInstanceOf[Iterator[CachedColumnarBatch]] + } + +} + + +private[sql] object GenerateColumnarBatch { + + def compressStorageLevel(storageLevel: StorageLevel, useCompression: Boolean): StorageLevel = { + if (!useCompression) return storageLevel + storageLevel match { + case MEMORY_ONLY => MEMORY_ONLY_SER + case MEMORY_ONLY_2 => MEMORY_ONLY_SER_2 + case MEMORY_AND_DISK => MEMORY_AND_DISK_SER + case MEMORY_AND_DISK_2 => MEMORY_AND_DISK_SER_2 + case sl => sl + } + } + + def isCompress(storageLevel: StorageLevel) : Boolean = { + (storageLevel == MEMORY_ONLY_SER || storageLevel == MEMORY_ONLY_SER_2 || + storageLevel == MEMORY_AND_DISK_SER || storageLevel == MEMORY_AND_DISK_SER_2) + } + + private val typeToName = Map[AbstractDataType, String]( + BooleanType -> "boolean", + ByteType -> "byte", + ShortType -> "short", + IntegerType -> "int", + LongType -> "long", + FloatType -> "float", + DoubleType -> "double", + DateType -> "int", + TimestampType -> "long", + StringType -> "UTF8String", + BinaryType -> "Binary" + ) + + def putColumnCode(ctx: CodegenContext, dt: DataType, nullable: Boolean, batchVar: String, + rowVar: String, rowNumVar: String, colStatVar: String, colNum: Int, numBytesVar: String) + : String = { + val colVar = s"$batchVar.column($colNum)" + val body = dt match { + case t if ctx.isPrimitiveType(dt) => + val typeName = GenerateColumnarBatch.typeToName(dt) + val put = "put" + typeName.capitalize + val get = "get" + typeName.capitalize + s""" + $typeName val = $rowVar.$get($colNum); + $colVar.$put($rowNumVar, val); + $numBytesVar += ${dt.defaultSize}; + $colStatVar.gatherValueStats(val); + """ + case StringType | BinaryType => + val typeName = GenerateColumnarBatch.typeToName(dt) + val typeDeclName = dt match { + case StringType => "UTF8String" + case BinaryType => "byte[]" + } + val put = "put" + typeName.capitalize + val get = "get" + typeName.capitalize + s""" + $typeDeclName val = $rowVar.$get($colNum); + int size = 
$colVar.$put($rowNumVar, val); + $numBytesVar += size; + $colStatVar.gatherValueStats(val, size); + """ + case NullType => + return s""" + if ($rowVar.isNullAt($colNum)) { + $colVar.putNull($rowNumVar); + } else { + $colVar.putNotNull($rowNumVar); + } + $numBytesVar += 1; + $colStatVar.gatherValueStats(null, 1); + """ + case dt: DecimalType => + val precision = dt.precision + val scale = dt.scale + s""" + Decimal val = $rowVar.getDecimal($colNum, $precision, $scale); + int size = $colVar.putDecimal($rowNumVar, val, $precision); + $numBytesVar += size; + $colStatVar.gatherValueStats(val, size); + """ + case array: ArrayType => + s""" + ArrayData val = $rowVar.getArray($colNum); + int size = $colVar.putArray($rowNumVar, val); + $numBytesVar += size; + $colStatVar.gatherValueStats(val, size); + """ + case t: MapType => + s""" + MapData val = $rowVar.getMap($colNum); + int size = $colVar.putMap($rowNumVar, val); + $numBytesVar += size; + $colStatVar.gatherValueStats(val, size); + """ + case struct: StructType => + s""" + InternalRow val = $rowVar.getStruct($colNum, ${struct.length}); + int size = $colVar.putStruct($rowNumVar,val); + $numBytesVar += size; + $colStatVar.gatherValueStats(val, size); + """ + case _ => + throw new UnsupportedOperationException("Unsupported data type " + dt.simpleString); + } + if (nullable) { + s""" + if ($rowVar.isNullAt($colNum)) { + $colVar.putNull($rowNumVar); + $colStatVar.gatherNullStats(); + } else { + ${body.trim} + } + """ + } else { + s""" + { + ${body.trim} + } + """ + } + } + + def generateStats(columnStats: Array[ColumnStats]): InternalRow = { + val array = columnStats.map(_.collectedStats).flatten + InternalRow.fromSeq(array) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 37bd95e73778..cf706dde1b83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -27,10 +27,45 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator +/** + * An abstract representation of a cached batch of rows. + */ +private[columnar] trait CachedBatch { + val stats: InternalRow + def getNumRows(): Int +} + + +/** + * A cached batch of rows stored as a list of byte arrays, one for each column. + * + * @param numRows The total number of rows in this batch + * @param buffers The serialized buffers for serialized columns + * @param stats The stat of columns + */ +private[columnar] case class CachedBatchBytes( + numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) + extends CachedBatch { + def getNumRows(): Int = numRows +} + + +/** + * A cached batch of rows stored as a [[ColumnarBatch]]. 
+ */ +private[columnar] case class CachedColumnarBatch(columnarBatch: ColumnarBatch, stats: InternalRow) + extends CachedBatch { + def getNumRows(): Int = columnarBatch.numRows() +} + + object InMemoryRelation { def apply( useCompression: Boolean, @@ -43,15 +78,11 @@ object InMemoryRelation { /** - * CachedBatch is a cached batch of rows. + * Container for a physical plan that should be cached in memory. * - * @param numRows The total number of rows in this batch - * @param buffers The buffers for serialized columns - * @param stats The stat of columns + * This batches the rows from that plan into [[CachedBatch]]es that are later consumed by + * [[InMemoryTableScanExec]]. */ -private[columnar] -case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) - case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, @@ -63,6 +94,14 @@ case class InMemoryRelation( val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) extends logical.LeafNode with MultiInstanceRelation { + /** + * If true, store the input rows using [[CachedColumnarBatch]]es, which are generally faster. + * If false, store the input rows using [[CachedBatchBytes]]. + */ + private[columnar] val useColumnarBatches: Boolean = { + child.sqlContext.conf.getConf(SQLConf.CACHE_CODEGEN) + } + override protected def innerChildren: Seq[SparkPlan] = Seq(child) override def producedAttributes: AttributeSet = outputSet @@ -79,23 +118,33 @@ case class InMemoryRelation( } } - // If the cached column buffers were not passed in, we calculate them in the constructor. - // As in Spark, the actual work of caching is lazy. - if (_cachedColumnBuffers == null) { - buildBuffers() - } - - def recache(): Unit = { - _cachedColumnBuffers.unpersist() - _cachedColumnBuffers = null - buildBuffers() + /** + * Batch the input rows into [[CachedBatch]]es. + */ + private def buildColumnBuffers: RDD[CachedBatch] = { + val buffers = + if (useColumnarBatches) { + buildColumnarBatches() + } else { + buildColumnBytes() + } + buffers.setName( + tableName.map { n => s"In-memory table $n" } + .getOrElse(StringUtils.abbreviate(child.toString, 1024))) + buffers.asInstanceOf[RDD[CachedBatch]] } - private def buildBuffers(): Unit = { + /** + * Batch the input rows into [[CachedBatchBytes]] built using [[ColumnBuilder]]s. + * + * This handles complex types and compression, but is more expensive than + * [[buildColumnarBatches]], which generates code to build the buffers. 
+ */ + private def buildColumnBytes(): RDD[CachedBatchBytes] = { val output = child.output - val cached = child.execute().mapPartitionsInternal { rowIterator => - new Iterator[CachedBatch] { - def next(): CachedBatch = { + child.execute().mapPartitionsInternal { rowIterator => + new Iterator[CachedBatchBytes] { + def next(): CachedBatchBytes = { val columnBuilders = output.map { attribute => ColumnBuilder(attribute.dataType, batchSize, attribute.name, useCompression) }.toArray @@ -130,7 +179,7 @@ case class InMemoryRelation( val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) .flatMap(_.values)) - CachedBatch(rowCount, columnBuilders.map { builder => + CachedBatchBytes(rowCount, columnBuilders.map { builder => JavaUtils.bufferToArray(builder.build()) }, stats) } @@ -138,11 +187,47 @@ case class InMemoryRelation( def hasNext: Boolean = rowIterator.hasNext } }.persist(storageLevel) + } - cached.setName( - tableName.map(n => s"In-memory table $n") - .getOrElse(StringUtils.abbreviate(child.toString, 1024))) - _cachedColumnBuffers = cached + /** + * Batch the input rows using [[ColumnarBatch]]es. + * + * Compared with [[buildColumnBytes]], this provides a faster implementation of memory + * scan because both the read path and the write path are generated. + * However, this does not compress data for now + */ + private def buildColumnarBatches(): RDD[CachedColumnarBatch] = { + val schema = StructType.fromAttributes(child.output) + val newStorageLevel = GenerateColumnarBatch.compressStorageLevel(storageLevel, useCompression) + val conf = child.sqlContext.sparkSession.sparkContext.conf + child.execute().mapPartitionsInternal { rows => + new GenerateColumnarBatch(schema, batchSize, newStorageLevel, conf).generate(rows).map { + cachedColumnarBatch => { + var i = 0 + var totalSize = 0L + while (i < cachedColumnarBatch.columnarBatch.numCols()) { + totalSize += cachedColumnarBatch.stats.getLong(4 + i * 5) + i += 1 + } + batchStats.add(totalSize) + cachedColumnarBatch + } + } + }.persist(storageLevel) + } + + // If the cached column buffers were not passed in, we calculate them in the constructor. + // As in Spark, the actual work of caching is lazy. + if (_cachedColumnBuffers == null) { + _cachedColumnBuffers = buildColumnBuffers + } + + def recache(): Unit = { + if (_cachedColumnBuffers != null) { + _cachedColumnBuffers.unpersist() + _cachedColumnBuffers = null + } + _cachedColumnBuffers = buildColumnBuffers } def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { @@ -163,7 +248,15 @@ case class InMemoryRelation( batchStats).asInstanceOf[this.type] } - def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers + /** + * Return lazily cached batches of rows in the original plan. 
+ */ + def cachedColumnBuffers: RDD[CachedBatch] = { + if (_cachedColumnBuffers == null) { + _cachedColumnBuffers = buildColumnBuffers + } + _cachedColumnBuffers + } override protected def otherCopyArgs: Seq[AnyRef] = Seq(_cachedColumnBuffers, batchStats) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 9028caa446e8..440410683f5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.LeafExecNode +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.UserDefinedType @@ -32,12 +32,57 @@ case class InMemoryTableScanExec( attributes: Seq[Attribute], predicates: Seq[Expression], @transient relation: InMemoryRelation) - extends LeafExecNode { + extends LeafExecNode with ColumnarBatchScan { + + override val columnIndexes = + attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray + + override val inMemoryTableScan = this + + override val supportCodegen: Boolean = relation.useColumnarBatches + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + if (relation.useColumnarBatches) { + val schema = relation.partitionStatistics.schema + val schemaIndex = schema.zipWithIndex + val buffers = relation.cachedColumnBuffers.asInstanceOf[RDD[CachedColumnarBatch]] + val prunedBuffers = if (inMemoryPartitionPruningEnabled) { + buffers.mapPartitionsInternal { cachedColumnarBatchIterator => + val partitionFilter = newPredicate( + partitionFilters.reduceOption(And).getOrElse(Literal(true)), schema) + + // Do partition batch pruning if enabled + cachedColumnarBatchIterator.filter { cachedColumnarBatch => + if (!partitionFilter.eval(cachedColumnarBatch.stats)) { + def statsString: String = schemaIndex.map { + case (a, i) => + val value = cachedColumnarBatch.stats.get(i, a.dataType) + s"${a.name}: $value" + }.mkString(", ") + logInfo(s"Skipping partition based on stats $statsString") + false + } else { + true + } + } + } + } else { + buffers + } + + // HACK ALERT: This is actually an RDD[CachedColumnarBatch]. + // We're taking advantage of Scala's type erasure here to pass these batches along. 
+ Seq(prunedBuffers.map(_.columnarBatch).asInstanceOf[RDD[InternalRow]]) + } else { + Seq() + } + } override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) override def output: Seq[Attribute] = attributes @@ -130,7 +175,9 @@ case class InMemoryTableScanExec( val schema = relation.partitionStatistics.schema val schemaIndex = schema.zipWithIndex val relOutput: AttributeSeq = relation.output + assert(relation.cachedColumnBuffers != null) val buffers = relation.cachedColumnBuffers + val conf = if (relation.useColumnarBatches) sqlContext.sparkContext.conf else null buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => val partitionFilter = newPredicate( @@ -169,7 +216,7 @@ case class InMemoryTableScanExec( if (enableAccumulators) { readBatches.add(1) } - numOutputRows += batch.numRows + numOutputRows += batch.getNumRows() batch } @@ -177,7 +224,7 @@ case class InMemoryTableScanExec( case udt: UserDefinedType[_] => udt.sqlType case other => other }.toArray - val columnarIterator = GenerateColumnAccessor.generate(columnTypes) + val columnarIterator = new GenerateColumnAccessor(conf).generate(columnTypes) columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) if (enableAccumulators && columnarIterator.hasNext) { readPartitions.add(1) @@ -185,4 +232,5 @@ case class InMemoryTableScanExec( columnarIterator } } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 645b0fa13ee3..f1a54d424ab6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -90,6 +90,21 @@ object SQLConf { .booleanConf .createWithDefault(true) + val CACHE_CODEGEN = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.codegen") + .internal() + .doc("When true, use generated code to build column batches for caching. This is only " + + "supported for basic types and improves caching performance for such types.") + .booleanConf + .createWithDefault(true) + + val CACHE_COMPRESSION_CODEC = + SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.compression.codec") + .internal() + .doc("Sets the compression codec use when columnar caching is compressed.") + .stringConf + .transform(_.toLowerCase()) + .createWithDefault("lz4") + val PREFER_SORTMERGEJOIN = SQLConfigBuilder("spark.sql.join.preferSortMergeJoin") .internal() .doc("When true, prefer sort merge join over shuffle hash join.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CacheBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CacheBenchmark.scala new file mode 100644 index 000000000000..0477bf08f28c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CacheBenchmark.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.SparkEnv +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Benchmark + + +class CacheBenchmark extends BenchmarkBase { + + ignore("cache with randomized keys - both build and read paths") { + benchmarkRandomizedKeys(size = 16 * 1024 * 1024, readPathOnly = false) + + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Cache random keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ---------------------------------------------------------------------------------------------- + cache = T columnarBatch = F compress = T 7211 / 7366 2.3 429.8 1.0X + cache = T columnarBatch = F compress = F 2381 / 2460 7.0 141.9 3.0X + cache = F 137 / 140 122.7 8.1 52.7X + cache = T columnarBatch = T compress = T 2109 / 2252 8.0 125.7 3.4X + cache = T columnarBatch = T compress = F 1126 / 1184 14.9 67.1 6.4X + */ + } + + ignore("cache with randomized keys - read path only") { + benchmarkRandomizedKeys(size = 64 * 1024 * 1024, readPathOnly = true) + + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Cache random keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ---------------------------------------------------------------------------------------------- + cache = T columnarBatch = F compress = T 1615 / 1655 41.5 24.1 1.0X + cache = T columnarBatch = F compress = F 1603 / 1690 41.9 23.9 1.0X + cache = F 444 / 449 151.3 6.6 3.6X + cache = T columnarBatch = T compress = T 1404 / 1526 47.8 20.9 1.2X + cache = T columnarBatch = T compress = F 116 / 125 579.0 1.7 13.9X + */ + } + + /** + * Call clean on a [[DataFrame]] after deleting all existing temporary files. + */ + private def clean(df: DataFrame): Unit = { + df.sparkSession.sparkContext.parallelize(1 to 10, 10).foreach { _ => + SparkEnv.get.blockManager.diskBlockManager.getAllFiles().foreach { dir => + dir.delete() + } + } + } + + /** + * Benchmark caching randomized keys created from a range. + * + * NOTE: When running this benchmark, you will get a lot of WARN logs complaining that the + * shuffle files do not exist. This is intentional; we delete the shuffle files manually + * after every call to `collect` to avoid the next run to reuse shuffle files written by + * the previous run. + */ + private def benchmarkRandomizedKeys(size: Int, readPathOnly: Boolean): Unit = { + val numIters = 10 + val benchmark = new Benchmark("Cache random keys", size) + sparkSession.range(size) + .selectExpr("id", "floor(rand() * 10000) as k") + .createOrReplaceTempView("test") + val query = "select count(k), count(id) from test" + + /** + * Add a benchmark case, optionally specifying whether to cache the dataset. 
+ */ + def addBenchmark(name: String, cache: Boolean, params: Map[String, String] = Map()): Unit = { + val ds = sparkSession.sql(query) + var dsResult: DataFrame = null + val defaults = params.keys.flatMap { k => sparkSession.conf.getOption(k).map((k, _)) } + def prepare(): Unit = { + clean(ds) + params.foreach { case (k, v) => sparkSession.conf.set(k, v) } + if (cache && readPathOnly) { + sparkSession.sql("cache table test") + } + } + def cleanup(): Unit = { + clean(dsResult) + defaults.foreach { case (k, v) => sparkSession.conf.set(k, v) } + sparkSession.catalog.clearCache() + } + benchmark.addCase(name, numIters, prepare, cleanup) { _ => + if (cache && !readPathOnly) { + sparkSession.sql("cache table test") + } + dsResult = sparkSession.sql(query) + dsResult.collect + } + } + + // All of these are codegen = T hashmap = T + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + // Benchmark cases: + // (1) Caching with compression + // (2) Caching without compression + // (3) No caching + // (4) Caching using column batch with compression + // (5) Caching using column batch without compression + addBenchmark("cache = T columnarBatch = F compress = T", cache = true, Map( + SQLConf.CACHE_CODEGEN.key -> "false", + SQLConf.COMPRESS_CACHED.key -> "true" + )) + addBenchmark("cache = T columnarBatch = F compress = F", cache = true, Map( + SQLConf.CACHE_CODEGEN.key -> "false", + SQLConf.COMPRESS_CACHED.key -> "false" + )) + addBenchmark("cache = F", cache = false) + addBenchmark("cache = T columnarBatch = T compress = T", cache = true, Map( + SQLConf.CACHE_CODEGEN.key -> "true", + SQLConf.COMPRESS_CACHED.key -> "true" + )) + addBenchmark("cache = T columnarBatch = T compress = F", cache = true, Map( + SQLConf.CACHE_CODEGEN.key -> "true", + SQLConf.COMPRESS_CACHED.key -> "false" + )) + benchmark.run() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index f355a5200ce2..15323e7c85a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.execution.columnar import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, GenericInternalRow} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ @@ -32,18 +34,24 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { setupTestData() - private def cachePrimitiveTest(data: DataFrame, dataType: String) { + def cachePrimitiveTest(data: DataFrame, dataType: String) { data.createOrReplaceTempView(s"testData$dataType") - val storageLevel = MEMORY_ONLY - val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan - val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None) - - assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) - inMemoryRelation.cachedColumnBuffers.collect().head match { - case _: CachedBatch => - case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}") + val useColumnBatches = true 
+ withSQLConf(SQLConf.CACHE_CODEGEN.key -> useColumnBatches.toString) { + Seq(MEMORY_ONLY, MEMORY_ONLY_SER).map { storageLevel => + val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan + val inMemoryRelation = InMemoryRelation(useCompression = false, 5, storageLevel, plan, None) + + assert(inMemoryRelation.useColumnarBatches == useColumnBatches) + assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) + inMemoryRelation.cachedColumnBuffers.collect().head match { + case _: CachedColumnarBatch => assert(useColumnBatches) + case _: CachedBatchBytes => assert(!useColumnBatches) + case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}") + } + checkAnswer(inMemoryRelation, data.collect().toSeq) + } } - checkAnswer(inMemoryRelation, data.collect().toSeq) } private def testPrimitiveType(nullability: Boolean): Unit = { @@ -69,7 +77,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { cachePrimitiveTest(spark.createDataFrame(rdd, schema), "primitivesDateTimeStamp") } - private def tesNonPrimitiveType(nullability: Boolean): Unit = { + private def testNonPrimitiveType(nullability: Boolean): Unit = { val struct = StructType(StructField("f1", FloatType, false) :: StructField("f2", ArrayType(BooleanType), true) :: Nil) val schema = StructType(Seq( @@ -103,11 +111,62 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val rddNull = spark.sparkContext.parallelize((1 to 10).map(i => Row(null))) cachePrimitiveTest(spark.createDataFrame(rddNull, schemaNull), "Null") - tesNonPrimitiveType(true) + testNonPrimitiveType(true) } test("non-primitive type with nullability:false") { - tesNonPrimitiveType(false) + testNonPrimitiveType(false) + } + + test("all data type w && w/o nullability") { + // all primitives + Seq(true, false).map { nullability => + val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DateType, TimestampType, DecimalType(25, 5), DecimalType(6, 5)) + val schema = StructType(dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullability) + }) + val rdd = spark.sparkContext.parallelize((1 to 10).map(i => Row( + if (nullability && i % 3 == 0) null else if (i % 2 == 0) true else false, + if (nullability && i % 3 == 0) null else i.toByte, + if (nullability && i % 3 == 0) null else i.toShort, + if (nullability && i % 3 == 0) null else i.toInt, + if (nullability && i % 3 == 0) null else i.toLong, + if (nullability && i % 3 == 0) null else (i + 0.25).toFloat, + if (nullability && i % 3 == 0) null else (i + 0.75).toDouble, + if (nullability && i % 3 == 0) null else new Date(i), + if (nullability && i % 3 == 0) null else new Timestamp(i * 1000000L), + if (nullability && i % 3 == 0) null else BigDecimal(Long.MaxValue.toString + ".12345"), + if (nullability && i % 3 == 0) null + else new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456") + ))) + cachePrimitiveTest(spark.createDataFrame(rdd, schema), "primitivesDateTimeStamp") + } + + val schemaNull = StructType(Seq(StructField("col", NullType, true))) + val rddNull = spark.sparkContext.parallelize((1 to 10).map(i => Row(null))) + cachePrimitiveTest(spark.createDataFrame(rddNull, schemaNull), "Null") + + Seq(true, false).map { nullability => + val struct = StructType(StructField("f1", FloatType, false) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val schema = StructType(Seq( + StructField("col0", StringType, nullability), + StructField("col1", 
ArrayType(IntegerType), nullability), + StructField("col2", ArrayType(ArrayType(IntegerType)), nullability), + StructField("col3", MapType(StringType, IntegerType), nullability), + StructField("col4", struct, nullability) + )) + val rdd = spark.sparkContext.parallelize((1 to 10).map(i => Row( + if (nullability && i % 3 == 0) null else s"str${i}: test cache.", + if (nullability && i % 3 == 0) null else (i * 100 to i * 100 + i).toArray, + if (nullability && i % 3 == 0) null + else Array(Array(i, i + 1), Array(i * 100 + 1, i * 100, i * 100 + 2)), + if (nullability && i % 3 == 0) null else (i to i + i).map(j => s"key$j" -> j).toMap, + if (nullability && i % 3 == 0) null else Row((i + 0.25).toFloat, Seq(true, false, null)) + ))) + cachePrimitiveTest(spark.createDataFrame(rdd, schema), "StringArrayMapStruct") + } } test("simple columnar query") { @@ -242,7 +301,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val allColumns = fields.map(_.name).mkString(",") val schema = StructType(fields) - // Create an RDD for the schema + // Create a RDD for the schema val rdd = sparkContext.parallelize((1 to 10000), 10).map { i => Row( @@ -310,12 +369,82 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-14138: Generated SpecificColumnarIterator can exceed JVM size limit for cached DF") { val length1 = 3999 val columnTypes1 = List.fill(length1)(IntegerType) - val columnarIterator1 = GenerateColumnAccessor.generate(columnTypes1) + val columnarIterator1 = new GenerateColumnAccessor(null).generate(columnTypes1) // SPARK-16664: the limit of janino is 8117 val length2 = 8117 val columnTypes2 = List.fill(length2)(IntegerType) - val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2) + val columnarIterator2 = new GenerateColumnAccessor(null).generate(columnTypes2) + } + + test("ColumnarBatch with many columns") { + val length1 = 9000 + val schema = StructType((1 to length1).map { case i => + StructField(s"col$i", IntegerType, true) + }) + val cachedBatch1 = new GenerateColumnarBatch(schema, 10000, MEMORY_ONLY, sparkConf). 
+ generate(Iterator.single(new GenericInternalRow((1 to length1).toArray[Any]))) + + val length2 = 9000 + val columnTypes2 = List.fill(length2)(IntegerType) + val columnarIterator2 = new GenerateColumnAccessor(sparkConf).generate(columnTypes2) + } + + test("access columns in CachedColumnarBatch without whole stage codegen") { + // whole stage codegen is not applied to a row with more than WHOLESTAGE_MAX_NUM_FIELDS fields + val dummySeq = Seq.range(0, 20) + val dummySchemas = dummySeq.map(i => StructField(s"d$i" + i, IntegerType, true)) + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "20") { + val data = Seq(null, true, 1.toByte, 3.toShort, 7, 15.toLong, + 31.25.toFloat, 63.75, new Date(127), new Timestamp(255000000L), null) + val dataTypes = Seq(NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DateType, TimestampType, IntegerType) + val schemas = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data ++ dummySeq))) + val df = spark.createDataFrame(rdd, StructType(schemas ++ dummySchemas)) + val row = df.persist.take(1).apply(0) + checkAnswer(df, row) + } + + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "20") { + val data = Seq(BigDecimal(Long.MaxValue.toString + ".12345"), + new java.math.BigDecimal("1234567890.12345"), + new java.math.BigDecimal("1.23456"), + "test123" + ) + val schemas = Seq( + StructField("col0", DecimalType(25, 5), true), + StructField("col1", DecimalType(15, 5), true), + StructField("col2", DecimalType(6, 5), true), + StructField("col3", StringType, true) + ) + val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data ++ dummySeq))) + val df = spark.createDataFrame(rdd, StructType(schemas ++ dummySchemas)) + val row = df.persist.take(1).apply(0) + checkAnswer(df, row) + } + + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "20") { + val data = Seq((1 to 10).toArray, + Array(Array(10, 11), Array(100, 111, 123)), + Map("key1" -> 111, "key2" -> 222), + Row(1.25.toFloat, Seq(true, false, null)) + ) + val struct = StructType(StructField("f1", FloatType, false) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val schemas = Seq( + StructField("col0", ArrayType(IntegerType), true), + StructField("col1", ArrayType(ArrayType(IntegerType)), true), + StructField("col2", MapType(StringType, IntegerType), true), + StructField("col3", struct, true) + ) + val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data ++ dummySeq))) + val df = spark.createDataFrame(rdd, StructType(schemas ++ dummySchemas)) + val row = df.persist.take(1).apply(0) + checkAnswer(df, row) + } } test("SPARK-17549: cached table size should be correctly calculated") { @@ -390,4 +519,41 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } } + test("InMemoryRelation builds the correct buffers with simple schemas") { + testColumnBatches(useColumnBatches = true, useComplexSchema = false) + testColumnBatches(useColumnBatches = false, useComplexSchema = false) + } + + test("InMemoryRelation builds the correct buffers with complex schemas") { + testColumnBatches(useColumnBatches = true, useComplexSchema = true) + testColumnBatches(useColumnBatches = false, useComplexSchema = true) + } + + private def testColumnBatches(useColumnBatches: Boolean, useComplexSchema: Boolean = false) { + withSQLConf(SQLConf.CACHE_CODEGEN.key -> useColumnBatches.toString) { + val logicalPlan = org.apache.spark.sql.catalyst.plans.logical.Range(1, 10, 1, 
10) + val sparkPlan = new org.apache.spark.sql.execution.RangeExec(logicalPlan) { + override val output: Seq[Attribute] = { + if (useComplexSchema) { + Seq(AttributeReference("complex", ArrayType(LongType))()) + } else { + logicalPlan.output + } + } + } + val inMemoryRelation = InMemoryRelation( + useCompression = false, + batchSize = 100, + storageLevel = MEMORY_ONLY, + child = sparkPlan, + tableName = None) + assert(inMemoryRelation.useColumnarBatches == useColumnBatches) + assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == MEMORY_ONLY) + inMemoryRelation.cachedColumnBuffers.collect().head match { + case _: CachedColumnarBatch => assert(useColumnBatches) + case _: CachedBatchBytes => assert(!useColumnBatches) + case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 8184d7d909f4..7882e1dbba3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -464,26 +464,30 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(2, 2, 0) column.putArray(3, 3, 3) - val a1 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] - val a2 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(1)).asInstanceOf[Array[Int]] - val a3 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(2)).asInstanceOf[Array[Int]] - val a4 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(3)).asInstanceOf[Array[Int]] + val a1 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(0).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] + val a2 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(1).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] + val a3 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(2).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] + val a4 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(3).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] assert(a1 === Array(0)) assert(a2 === Array(1, 2)) assert(a3 === Array.empty[Int]) assert(a4 === Array(3, 4, 5)) // Verify the ArrayData APIs - assert(column.getArray(0).length == 1) + assert(column.getArray(0).numElements() == 1) assert(column.getArray(0).getInt(0) == 0) - assert(column.getArray(1).length == 2) + assert(column.getArray(1).numElements() == 2) assert(column.getArray(1).getInt(0) == 1) assert(column.getArray(1).getInt(1) == 2) - assert(column.getArray(2).length == 0) + assert(column.getArray(2).numElements() == 0) - assert(column.getArray(3).length == 3) + assert(column.getArray(3).numElements() == 3) assert(column.getArray(3).getInt(0) == 3) assert(column.getArray(3).getInt(1) == 4) assert(column.getArray(3).getInt(2) == 5) @@ -496,8 +500,8 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) column.putArray(0, 0, array.length) - assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] - === array) + assert(ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(0).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] === array) }} }
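
Usage note (not part of the patch): the sketch below shows one way to exercise the new cache-codegen path end to end, assuming this patch is applied and a local SparkSession is available. The config key `spark.sql.inMemoryColumnarStorage.codegen` comes from the SQLConf addition above; the object name `CacheCodegenExample` and the sample data are illustrative only. When the flag is true (the default in this patch), InMemoryRelation builds CachedColumnarBatch instances via generated code; when false it falls back to the existing CachedBatchBytes path built with ColumnBuilders, as the InMemoryColumnarQuerySuite tests above check.

    import org.apache.spark.sql.SparkSession

    object CacheCodegenExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("cache-codegen-example")
          .getOrCreate()

        // Toggle the generated column-batch cache path added by this patch.
        // Setting it to "false" would exercise the byte-buffer based path instead.
        spark.conf.set("spark.sql.inMemoryColumnarStorage.codegen", "true")

        spark.range(0, 1000000)
          .selectExpr("id", "cast(id % 10 as int) as k")
          .createOrReplaceTempView("t")

        spark.sql("CACHE TABLE t")                 // materializes the cached batches lazily on first scan
        spark.sql("SELECT count(k) FROM t").show() // scan is served by InMemoryTableScanExec

        spark.sql("UNCACHE TABLE t")
        spark.stop()
      }
    }

Note that the example leaves `spark.sql.inMemoryColumnarStorage.compressed` at its default; per compressStorageLevel in the diff, enabling compression only remaps the storage level (for example MEMORY_ONLY to MEMORY_ONLY_SER) rather than compressing the ColumnarBatch contents directly.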