
Commit 7404924

liancheng authored and marmbrus committed
[SPARK-3294][SQL] Eliminates boxing costs from in-memory columnar storage
This is a major refactoring of the in-memory columnar storage implementation that aims to eliminate boxing costs from the critical paths (building/accessing column buffers) as much as possible. The basic idea is to refactor all major interfaces into a row-based form and use them together with `SpecificMutableRow`. The difficult part is adapting all compression schemes, especially `RunLengthEncoding` and `DictionaryEncoding`, to this design. Since in-memory compression is disabled by default for now, and this PR should be strictly better than before whether or not in-memory compression is enabled, I'll probably finish that part in another PR.

**UPDATE** This PR also takes the chance to optimize `HiveTableScan` by

1. leveraging `SpecificMutableRow` to avoid boxing costs, and
2. building specific `Writable` unwrapper functions ahead of time to avoid per-row pattern matching and branching costs.

TODO

- [x] Benchmark
- [ ] ~~Eliminate boxing costs in `RunLengthEncoding`~~ (left to future PRs)
- [ ] ~~Eliminate boxing costs in `DictionaryEncoding` (seems hard to do without specializing `DictionaryEncoding` for every supported column type)~~ (left to future PRs)

## Micro benchmark

The benchmark uses a 10-million-line CSV table consisting of bytes, shorts, integers, longs, floats, and doubles. It measures the time to build the in-memory version of this table and the time to scan the whole in-memory table. Benchmark code can be found [here](https://gist.github.com/liancheng/fe70a148de82e77bd2c8#file-hivetablescanbenchmark-scala). The script used to generate the input table can be found [here](https://gist.github.com/liancheng/fe70a148de82e77bd2c8#file-tablegen-scala).

Speedup:

- Hive table scanning + column buffer building: **18.74%**

  The original benchmark uses 1K as the in-memory batch size; when increased to 10K, the speedup reaches 28.32%.

- In-memory table scanning: **7.95%**

Before:

| Run     | Building | Scanning |
| ------- | -------- | -------- |
| 1       | 16472    | 525      |
| 2       | 16168    | 530      |
| 3       | 16386    | 529      |
| 4       | 16184    | 538      |
| 5       | 16209    | 521      |
| Average | 16283.8  | 528.6    |

After:

| Run     | Building | Scanning |
| ------- | -------- | -------- |
| 1       | 13124    | 458      |
| 2       | 13260    | 529      |
| 3       | 12981    | 463      |
| 4       | 13214    | 483      |
| 5       | 13583    | 500      |
| Average | 13232.4  | 486.6    |

Author: Cheng Lian <[email protected]>

Closes #2327 from liancheng/prevent-boxing/unboxing and squashes the following commits:

4419fe4 [Cheng Lian] Addressing comments
e5d2cf2 [Cheng Lian] Bug fix: should call setNullAt when field value is null to avoid NPE
8b8552b [Cheng Lian] Only checks for partition batch pruning flag once
489f97b [Cheng Lian] Bug fix: TableReader.fillObject uses wrong ordinals
97bbc4e [Cheng Lian] Optimizes hive.TableReader by providing specific Writable unwrappers ahead of time
3dc1f94 [Cheng Lian] Minor changes to eliminate row object creation
5b39cb9 [Cheng Lian] Lowers log level of compression scheme details
f2a7890 [Cheng Lian] Use SpecificMutableRow in InMemoryColumnarTableScan to avoid boxing
9cf30b0 [Cheng Lian] Added row based ColumnType.append/extract
456c366 [Cheng Lian] Made compression decoder row based
edac3cd [Cheng Lian] Makes ColumnAccessor.extractSingle row based
8216936 [Cheng Lian] Removes boxing cost in IntDelta and LongDelta by providing specialized implementations
b70d519 [Cheng Lian] Made some in-memory columnar storage interfaces row-based
1 parent 184cd51 · commit 7404924
24 files changed: +554 −292 lines
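
To make the boxing argument concrete, here is a minimal, hypothetical sketch of the idea behind `SpecificMutableRow` (simplified stand-in types, not the actual Spark classes): a generic row stores fields as `Any`, so every primitive write allocates a box, while a specialized row keeps a primitively typed holder per column so values stay unboxed on the hot path.

```scala
// Hypothetical sketch, not the real Spark implementation.

// Generic row: setInt boxes the Int into an Array[Any] slot.
class GenericRowSketch(arity: Int) {
  private val values = new Array[Any](arity)
  def setInt(ordinal: Int, v: Int): Unit = values(ordinal) = v        // boxes
  def getInt(ordinal: Int): Int = values(ordinal).asInstanceOf[Int]  // unboxes
}

// Specialized row: one primitive holder per column, no boxing involved.
final class MutableIntValue { var value: Int = 0; var isNull: Boolean = true }

class SpecificRowSketch(arity: Int) {
  private val values = Array.fill(arity)(new MutableIntValue)
  def setInt(ordinal: Int, v: Int): Unit = {
    val holder = values(ordinal)
    holder.value = v        // plain primitive store
    holder.isNull = false
  }
  def getInt(ordinal: Int): Int = values(ordinal).value
}
```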

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -171,7 +171,7 @@ final class MutableByte extends MutableValue {
 }
 
 final class MutableAny extends MutableValue {
-  var value: Any = 0
+  var value: Any = _
   def boxed = if (isNull) null else value
   def update(v: Any) = value = {
     isNull = false
```
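
A side note on this one-line fix: in a Scala class body, `var value: Any = 0` boxes the literal `0` into a `java.lang.Integer` on construction, while `var value: Any = _` default-initializes the field to `null`, which is the intended "no value yet" state. A small illustrative check (hypothetical names):

```scala
class HolderSketch {
  var boxed: Any = 0  // boxed to java.lang.Integer when the instance is created
  var empty: Any = _  // default-initialized to null (only valid in class bodies)
}

val h = new HolderSketch
println(h.boxed.getClass) // class java.lang.Integer
println(h.empty == null)  // true
```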

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala

Lines changed: 5 additions & 3 deletions
```diff
@@ -50,11 +50,13 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](
 
   def hasNext = buffer.hasRemaining
 
-  def extractTo(row: MutableRow, ordinal: Int) {
-    columnType.setField(row, ordinal, extractSingle(buffer))
+  def extractTo(row: MutableRow, ordinal: Int): Unit = {
+    extractSingle(row, ordinal)
   }
 
-  def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer)
+  def extractSingle(row: MutableRow, ordinal: Int): Unit = {
+    columnType.extract(buffer, row, ordinal)
+  }
 
   protected def underlyingBuffer = buffer
 }
```
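
A hedged sketch of what the row-based `extractSingle` buys (simplified stand-in types, illustration only; the real signatures live in `ColumnType` and `MutableRow`): the column type reads the primitive from the buffer and writes it straight into the mutable row, instead of returning a value that a generic `setField(row, ordinal, value)` would have to box.

```scala
import java.nio.ByteBuffer

// Simplified stand-ins for illustration only.
trait MutableRowSketch { def setInt(ordinal: Int, value: Int): Unit }

object IntColumnTypeSketch {
  // Row-based extract: the Int goes from buffer to row without ever being boxed.
  def extract(buffer: ByteBuffer, row: MutableRowSketch, ordinal: Int): Unit =
    row.setInt(ordinal, buffer.getInt())
}
```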

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala

Lines changed: 13 additions & 14 deletions
```diff
@@ -68,10 +68,9 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
     buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId)
   }
 
-  override def appendFrom(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    buffer = ensureFreeSpace(buffer, columnType.actualSize(field))
-    columnType.append(field, buffer)
+  override def appendFrom(row: Row, ordinal: Int): Unit = {
+    buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal))
+    columnType.append(row, ordinal, buffer)
   }
 
   override def build() = {
@@ -142,16 +141,16 @@ private[sql] object ColumnBuilder {
       useCompression: Boolean = false): ColumnBuilder = {
 
     val builder = (typeId match {
-      case INT.typeId => new IntColumnBuilder
-      case LONG.typeId => new LongColumnBuilder
-      case FLOAT.typeId => new FloatColumnBuilder
-      case DOUBLE.typeId => new DoubleColumnBuilder
-      case BOOLEAN.typeId => new BooleanColumnBuilder
-      case BYTE.typeId => new ByteColumnBuilder
-      case SHORT.typeId => new ShortColumnBuilder
-      case STRING.typeId => new StringColumnBuilder
-      case BINARY.typeId => new BinaryColumnBuilder
-      case GENERIC.typeId => new GenericColumnBuilder
+      case INT.typeId       => new IntColumnBuilder
+      case LONG.typeId      => new LongColumnBuilder
+      case FLOAT.typeId     => new FloatColumnBuilder
+      case DOUBLE.typeId    => new DoubleColumnBuilder
+      case BOOLEAN.typeId   => new BooleanColumnBuilder
+      case BYTE.typeId      => new ByteColumnBuilder
+      case SHORT.typeId     => new ShortColumnBuilder
+      case STRING.typeId    => new StringColumnBuilder
+      case BINARY.typeId    => new BinaryColumnBuilder
+      case GENERIC.typeId   => new GenericColumnBuilder
       case TIMESTAMP.typeId => new TimestampColumnBuilder
     }).asInstanceOf[ColumnBuilder]
 
```
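
The builder side mirrors the accessor change: `appendFrom` no longer pulls the field out of the row as a boxed value; sizing and appending are done straight from the row. A minimal sketch under the same simplified stand-ins (hypothetical, not the actual API):

```scala
import java.nio.ByteBuffer

trait RowSketch { def getInt(ordinal: Int): Int }

object IntColumnTypeSketch {
  // Row-based sizing and appending: the Int stays primitive end to end.
  def actualSize(row: RowSketch, ordinal: Int): Int = 4
  def append(row: RowSketch, ordinal: Int, buffer: ByteBuffer): Unit =
    buffer.putInt(row.getInt(ordinal))
}

// Usage sketch: no intermediate boxed `field` value is materialized.
def appendFrom(row: RowSketch, ordinal: Int, buffer: ByteBuffer): Unit = {
  // (buffer growth via ensureFreeSpace elided for brevity)
  IntColumnTypeSketch.append(row, ordinal, buffer)
}
```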

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala

Lines changed: 8 additions & 8 deletions
```diff
@@ -69,7 +69,7 @@ private[sql] class ByteColumnStats extends ColumnStats {
   var lower = Byte.MaxValue
   var nullCount = 0
 
-  override def gatherStats(row: Row, ordinal: Int) {
+  override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getByte(ordinal)
       if (value > upper) upper = value
@@ -87,7 +87,7 @@ private[sql] class ShortColumnStats extends ColumnStats {
   var lower = Short.MaxValue
   var nullCount = 0
 
-  override def gatherStats(row: Row, ordinal: Int) {
+  override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getShort(ordinal)
       if (value > upper) upper = value
@@ -105,7 +105,7 @@ private[sql] class LongColumnStats extends ColumnStats {
   var lower = Long.MaxValue
   var nullCount = 0
 
-  override def gatherStats(row: Row, ordinal: Int) {
+  override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getLong(ordinal)
       if (value > upper) upper = value
@@ -123,7 +123,7 @@ private[sql] class DoubleColumnStats extends ColumnStats {
   var lower = Double.MaxValue
   var nullCount = 0
 
-  override def gatherStats(row: Row, ordinal: Int) {
+  override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getDouble(ordinal)
       if (value > upper) upper = value
@@ -141,7 +141,7 @@ private[sql] class FloatColumnStats extends ColumnStats {
   var lower = Float.MaxValue
   var nullCount = 0
 
-  override def gatherStats(row: Row, ordinal: Int) {
+  override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getFloat(ordinal)
       if (value > upper) upper = value
@@ -159,7 +159,7 @@ private[sql] class IntColumnStats extends ColumnStats {
   var lower = Int.MaxValue
   var nullCount = 0
 
-  override def gatherStats(row: Row, ordinal: Int) {
+  override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getInt(ordinal)
       if (value > upper) upper = value
@@ -177,7 +177,7 @@ private[sql] class StringColumnStats extends ColumnStats {
   var lower: String = null
   var nullCount = 0
 
-  override def gatherStats(row: Row, ordinal: Int) {
+  override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getString(ordinal)
       if (upper == null || value.compareTo(upper) > 0) upper = value
@@ -195,7 +195,7 @@ private[sql] class TimestampColumnStats extends ColumnStats {
   var lower: Timestamp = null
   var nullCount = 0
 
-  override def gatherStats(row: Row, ordinal: Int) {
+  override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row(ordinal).asInstanceOf[Timestamp]
       if (upper == null || value.compareTo(upper) > 0) upper = value
```
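
The repeated `gatherStats(row, ordinal): Unit` change above applies the same row-based pattern to statistics collection: each collector reads its value through the typed getter rather than a boxing `Any` accessor. A hedged sketch of one collector in the style of `IntColumnStats` (simplified trait, illustration only):

```scala
trait RowSketch {
  def isNullAt(ordinal: Int): Boolean
  def getInt(ordinal: Int): Int
}

class IntColumnStatsSketch {
  var upper = Int.MinValue
  var lower = Int.MaxValue
  var nullCount = 0

  def gatherStats(row: RowSketch, ordinal: Int): Unit = {
    if (!row.isNullAt(ordinal)) {
      val value = row.getInt(ordinal) // typed getter: no boxing on the hot path
      if (value > upper) upper = value
      if (value < lower) lower = value
    } else {
      nullCount += 1
    }
  }
}
```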
