
Commit b70d519

Made some in-memory columnar storage interfaces row-based
1 parent 25b5b86 commit b70d519

8 files changed: 73 additions, 56 deletions

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala

Lines changed: 2 additions & 3 deletions
@@ -69,9 +69,8 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
   }
 
   override def appendFrom(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    buffer = ensureFreeSpace(buffer, columnType.actualSize(field))
-    columnType.append(field, buffer)
+    buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal))
+    columnType.append(columnType.getField(row, ordinal), buffer)
   }
 
   override def build() = {

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala

Lines changed: 13 additions & 8 deletions
@@ -18,13 +18,12 @@
 package org.apache.spark.sql.columnar
 
 import java.nio.ByteBuffer
+import java.sql.Timestamp
 
 import scala.reflect.runtime.universe.TypeTag
 
-import java.sql.Timestamp
-
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.MutableRow
+import org.apache.spark.sql.catalyst.expressions.{MutableAny, MutableRow, MutableValue}
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.execution.SparkSqlSerializer
 
@@ -41,6 +40,8 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
     val typeId: Int,
     val defaultSize: Int) {
 
+  val mutable: MutableValue = new MutableAny
+
   /**
    * Extracts a value out of the buffer at the buffer's current position.
    */
@@ -52,10 +53,10 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
   def append(v: JvmType, buffer: ByteBuffer)
 
   /**
-   * Returns the size of the value. This is used to calculate the size of variable length types
-   * such as byte arrays and strings.
+   * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable
+   * length types such as byte arrays and strings.
    */
-  def actualSize(v: JvmType): Int = defaultSize
+  def actualSize(row: Row, ordinal: Int): Int = defaultSize
 
   /**
    * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs
@@ -200,7 +201,9 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
 }
 
 private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
-  override def actualSize(v: String): Int = v.getBytes("utf-8").length + 4
+  override def actualSize(row: Row, ordinal: Int): Int = {
+    row.getString(ordinal).getBytes("utf-8").length + 4
+  }
 
   override def append(v: String, buffer: ByteBuffer) {
     val stringBytes = v.getBytes("utf-8")
@@ -246,7 +249,9 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
     defaultSize: Int)
   extends ColumnType[T, Array[Byte]](typeId, defaultSize) {
 
-  override def actualSize(v: Array[Byte]) = v.length + 4
+  override def actualSize(row: Row, ordinal: Int) = {
+    getField(row, ordinal).length + 4
+  }
 
   override def append(v: Array[Byte], buffer: ByteBuffer) {
     buffer.putInt(v.length).put(v, 0, v.length)
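The heart of this file's change is that `actualSize` no longer receives a pre-extracted JVM value; it reads the value straight out of a row position. A minimal, self-contained sketch of that contract (simplified stand-ins with made-up names such as `SimpleRow` and `SketchColumnType`, not Spark's real classes):

  trait SimpleRow {
    def getString(ordinal: Int): String
  }

  abstract class SketchColumnType(val defaultSize: Int) {
    // Fixed-width types can keep the default; variable-length types override it.
    def actualSize(row: SimpleRow, ordinal: Int): Int = defaultSize
  }

  object SketchStringColumn extends SketchColumnType(defaultSize = 8) {
    // 4 bytes for the length prefix plus the UTF-8 payload, read directly from the row.
    override def actualSize(row: SimpleRow, ordinal: Int): Int =
      row.getString(ordinal).getBytes("utf-8").length + 4
  }

Letting the column type pull the field itself is what allows primitive columns to avoid the boxing a `JvmType` parameter would force, which is the stated rationale of the `getField`/`actualSize` API in this file.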

sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala

Lines changed: 5 additions & 7 deletions
@@ -51,9 +51,9 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
   abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) {
     compressionEncoders =
       if (useCompression) {
-        schemes.filter(_.supports(columnType)).map(_.encoder[T])
+        schemes.filter(_.supports(columnType)).map(_.encoder[T](columnType))
       } else {
-        Seq(PassThrough.encoder)
+        Seq(PassThrough.encoder(columnType))
       }
     super.initialize(initialSize, columnName, useCompression)
   }
@@ -63,11 +63,9 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
   }
 
   private def gatherCompressibilityStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-
     var i = 0
     while (i < compressionEncoders.length) {
-      compressionEncoders(i).gatherCompressibilityStats(field, columnType)
+      compressionEncoders(i).gatherCompressibilityStats(row, ordinal)
       i += 1
     }
   }
@@ -84,7 +82,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
     val typeId = nonNullBuffer.getInt()
     val encoder: Encoder[T] = {
      val candidate = compressionEncoders.minBy(_.compressionRatio)
-      if (isWorthCompressing(candidate)) candidate else PassThrough.encoder
+      if (isWorthCompressing(candidate)) candidate else PassThrough.encoder(columnType)
     }
 
     // Header = column type ID + null count + null positions
@@ -105,6 +103,6 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
       .put(nulls)
 
     logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
-    encoder.compress(nonNullBuffer, compressedBuffer, columnType)
+    encoder.compress(nonNullBuffer, compressedBuffer)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala

Lines changed: 5 additions & 4 deletions
@@ -17,13 +17,14 @@
 
 package org.apache.spark.sql.columnar.compression
 
-import java.nio.{ByteOrder, ByteBuffer}
+import java.nio.{ByteBuffer, ByteOrder}
 
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.types.NativeType
 import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType}
 
 private[sql] trait Encoder[T <: NativeType] {
-  def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {}
+  def gatherCompressibilityStats(row: Row, ordinal: Int) {}
 
   def compressedSize: Int
 
@@ -33,7 +34,7 @@ private[sql] trait Encoder[T <: NativeType] {
     if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
   }
 
-  def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]): ByteBuffer
+  def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer
 }
 
 private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType]
@@ -43,7 +44,7 @@ private[sql] trait CompressionScheme {
 
   def supports(columnType: ColumnType[_, _]): Boolean
 
-  def encoder[T <: NativeType]: Encoder[T]
+  def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T]
 
   def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T]
 }
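To see how the reshaped traits fit together: a scheme now receives the column type once, when the encoder is built, so `gatherCompressibilityStats` and `compress` no longer need it as a parameter. A rough stand-alone sketch under those assumptions (toy names like `SketchScheme` and `SketchIntColumn`, not the actual Spark API):

  import java.nio.ByteBuffer

  trait SketchRow { def getInt(ordinal: Int): Int }
  class SketchIntColumn { def actualSize(row: SketchRow, ordinal: Int): Int = 4 }

  trait SketchEncoder {
    def gatherCompressibilityStats(row: SketchRow, ordinal: Int): Unit = {}
    def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer
  }

  object SketchScheme {
    // The column type is captured here instead of being threaded through every call.
    def encoder(columnType: SketchIntColumn): SketchEncoder = new SketchEncoder {
      private var uncompressedSize = 0

      // Row-based stats: sizes come from the captured column type and the row position.
      override def gatherCompressibilityStats(row: SketchRow, ordinal: Int): Unit = {
        uncompressedSize += columnType.actualSize(row, ordinal)
      }

      override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
        to.put(from).rewind() // pass-through style: raw copy, then reset for reading
        to
      }
    }
  }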

sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala

Lines changed: 42 additions & 31 deletions
@@ -23,7 +23,8 @@ import scala.collection.mutable
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.runtimeMirror
 
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.columnar._
 import org.apache.spark.util.Utils
@@ -33,18 +34,20 @@ private[sql] case object PassThrough extends CompressionScheme {
 
   override def supports(columnType: ColumnType[_, _]) = true
 
-  override def encoder[T <: NativeType] = new this.Encoder[T]
+  override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+    new this.Encoder[T](columnType)
+  }
 
   override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
     new this.Decoder(buffer, columnType)
   }
 
-  class Encoder[T <: NativeType] extends compression.Encoder[T] {
+  class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
     override def uncompressedSize = 0
 
     override def compressedSize = 0
 
-    override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = {
+    override def compress(from: ByteBuffer, to: ByteBuffer) = {
       // Writes compression type ID and copies raw contents
       to.putInt(PassThrough.typeId).put(from).rewind()
       to
@@ -63,7 +66,9 @@ private[sql] case object PassThrough extends CompressionScheme {
 private[sql] case object RunLengthEncoding extends CompressionScheme {
   override val typeId = 1
 
-  override def encoder[T <: NativeType] = new this.Encoder[T]
+  override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+    new this.Encoder[T](columnType)
+  }
 
   override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
     new this.Decoder(buffer, columnType)
@@ -74,20 +79,21 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
     case _ => false
   }
 
-  class Encoder[T <: NativeType] extends compression.Encoder[T] {
+  class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
     private var _uncompressedSize = 0
     private var _compressedSize = 0
 
     // Using `MutableRow` to store the last value to avoid boxing/unboxing cost.
-    private val lastValue = new GenericMutableRow(1)
+    private val lastValue = new SpecificMutableRow(Seq(columnType.dataType))
     private var lastRun = 0
 
     override def uncompressedSize = _uncompressedSize
 
     override def compressedSize = _compressedSize
 
-    override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {
-      val actualSize = columnType.actualSize(value)
+    override def gatherCompressibilityStats(row: Row, ordinal: Int) {
+      val value = columnType.getField(row, ordinal)
+      val actualSize = columnType.actualSize(row, ordinal)
       _uncompressedSize += actualSize
 
       if (lastValue.isNullAt(0)) {
@@ -105,7 +111,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
       }
     }
 
-    override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = {
+    override def compress(from: ByteBuffer, to: ByteBuffer) = {
       to.putInt(RunLengthEncoding.typeId)
 
       if (from.hasRemaining) {
@@ -171,14 +177,16 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
     new this.Decoder(buffer, columnType)
   }
 
-  override def encoder[T <: NativeType] = new this.Encoder[T]
+  override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+    new this.Encoder[T](columnType)
+  }
 
   override def supports(columnType: ColumnType[_, _]) = columnType match {
     case INT | LONG | STRING => true
     case _ => false
   }
 
-  class Encoder[T <: NativeType] extends compression.Encoder[T] {
+  class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
     // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary
     // overflows.
     private var _uncompressedSize = 0
@@ -200,9 +208,11 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
     // to store dictionary element count.
     private var dictionarySize = 4
 
-    override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {
+    override def gatherCompressibilityStats(row: Row, ordinal: Int) {
+      val value = columnType.getField(row, ordinal)
+
       if (!overflow) {
-        val actualSize = columnType.actualSize(value)
+        val actualSize = columnType.actualSize(row, ordinal)
         count += 1
         _uncompressedSize += actualSize
 
@@ -221,7 +231,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
       }
     }
 
-    override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = {
+    override def compress(from: ByteBuffer, to: ByteBuffer) = {
       if (overflow) {
        throw new IllegalStateException(
          "Dictionary encoding should not be used because of dictionary overflow.")
@@ -279,25 +289,20 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
     new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]]
   }
 
-  override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]]
+  override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+    (new this.Encoder).asInstanceOf[compression.Encoder[T]]
+  }
 
   override def supports(columnType: ColumnType[_, _]) = columnType == BOOLEAN
 
   class Encoder extends compression.Encoder[BooleanType.type] {
     private var _uncompressedSize = 0
 
-    override def gatherCompressibilityStats(
-        value: Boolean,
-        columnType: NativeColumnType[BooleanType.type]) {
-
+    override def gatherCompressibilityStats(row: Row, ordinal: Int) {
       _uncompressedSize += BOOLEAN.defaultSize
     }
 
-    override def compress(
-        from: ByteBuffer,
-        to: ByteBuffer,
-        columnType: NativeColumnType[BooleanType.type]) = {
-
+    override def compress(from: ByteBuffer, to: ByteBuffer) = {
       to.putInt(BooleanBitSet.typeId)
         // Total element count (1 byte per Boolean value)
         .putInt(from.remaining)
@@ -364,13 +369,18 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
   }
 }
 
-private[sql] sealed abstract class IntegralDelta[I <: IntegralType] extends CompressionScheme {
+private[sql] sealed abstract class IntegralDelta[I <: IntegralType](
+    columnType: NativeColumnType[I])
+  extends CompressionScheme {
+
   override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
     new this.Decoder(buffer, columnType.asInstanceOf[NativeColumnType[I]])
       .asInstanceOf[compression.Decoder[T]]
   }
 
-  override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]]
+  override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+    (new this.Encoder).asInstanceOf[compression.Encoder[T]]
+  }
 
   /**
    * Computes `delta = x - y`, returns `(true, delta)` if `delta` can fit into a single byte, or
@@ -392,7 +402,8 @@ private[sql] sealed abstract class IntegralDelta[I <: IntegralType] extends Comp
 
     private var initial = true
 
-    override def gatherCompressibilityStats(value: I#JvmType, columnType: NativeColumnType[I]) {
+    override def gatherCompressibilityStats(row: Row, ordinal: Int) {
+      val value = columnType.getField(row, ordinal)
       _uncompressedSize += columnType.defaultSize
 
       if (initial) {
@@ -406,7 +417,7 @@ private[sql] sealed abstract class IntegralDelta[I <: IntegralType] extends Comp
       prev = value
     }
 
-    override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[I]) = {
+    override def compress(from: ByteBuffer, to: ByteBuffer) = {
       to.putInt(typeId)
 
       if (from.hasRemaining) {
@@ -452,7 +463,7 @@ private[sql] sealed abstract class IntegralDelta[I <: IntegralType] extends Comp
   }
 }
 
-private[sql] case object IntDelta extends IntegralDelta[IntegerType.type] {
+private[sql] case object IntDelta extends IntegralDelta[IntegerType.type](INT) {
   override val typeId = 4
 
   override def supports(columnType: ColumnType[_, _]) = columnType == INT
@@ -465,7 +476,7 @@ private[sql] case object IntDelta extends IntegralDelta[IntegerType.type] {
   }
 }
 
-private[sql] case object LongDelta extends IntegralDelta[LongType.type] {
+private[sql] case object LongDelta extends IntegralDelta[LongType.type](LONG) {
   override val typeId = 5
 
   override def supports(columnType: ColumnType[_, _]) = columnType == LONG
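As a rough end-to-end illustration of the new flow in this file — an encoder is constructed per column type and then fed raw rows — here is a standalone toy model (names like `ToyRow` and `ToyRunLengthStats` are invented for the example; this is not the real Spark code):

  final case class ToyRow(value: Int)

  class ToyIntColumn {
    def getField(row: ToyRow, ordinal: Int): Int = row.value
    def actualSize(row: ToyRow, ordinal: Int): Int = 4
  }

  // Mirrors the reworked run-length encoder: the column type is fixed at construction,
  // and statistics are gathered directly from (row, ordinal) pairs.
  class ToyRunLengthStats(columnType: ToyIntColumn) {
    private var uncompressed = 0
    private var runs = 0
    private var last: Option[Int] = None

    def gatherCompressibilityStats(row: ToyRow, ordinal: Int): Unit = {
      uncompressed += columnType.actualSize(row, ordinal)
      val v = columnType.getField(row, ordinal)
      if (last != Some(v)) { runs += 1; last = Some(v) }
    }

    def uncompressedSize: Int = uncompressed
    def compressedSize: Int = runs * (4 + 4) // one (value, run length) pair per run
  }

  object ToyDemo extends App {
    val stats = new ToyRunLengthStats(new ToyIntColumn)
    Seq(1, 1, 1, 1, 2, 2, 2, 2).map(ToyRow(_)).foreach(stats.gatherCompressibilityStats(_, 0))
    println(s"${stats.compressedSize} bytes vs ${stats.uncompressedSize} uncompressed") // 16 vs 32
  }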

sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala

Lines changed: 4 additions & 1 deletion
@@ -23,6 +23,7 @@ import java.sql.Timestamp
 import org.scalatest.FunSuite
 
 import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.execution.SparkSqlSerializer
@@ -49,7 +50,9 @@ class ColumnTypeSuite extends FunSuite with Logging {
       expected: Int) {
 
     assertResult(expected, s"Wrong actualSize for $columnType") {
-      columnType.actualSize(value)
+      val row = new GenericMutableRow(1)
+      columnType.setField(row, 0, value)
+      columnType.actualSize(row, 0)
     }
   }

sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ class DictionaryEncodingSuite extends FunSuite {
        val buffer = builder.build()
        val headerSize = CompressionScheme.columnHeaderSize(buffer)
        // 4 extra bytes for dictionary size
-       val dictionarySize = 4 + values.map(columnType.actualSize).sum
+       val dictionarySize = 4 + rows.map(columnType.actualSize(_, 0)).sum
        // 2 bytes for each `Short`
        val compressedSize = 4 + dictionarySize + 2 * inputSeq.length
        // 4 extra bytes for compression scheme type ID

sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ class RunLengthEncodingSuite extends FunSuite {
       // Compression scheme ID + compressed contents
       val compressedSize = 4 + inputRuns.map { case (index, _) =>
         // 4 extra bytes each run for run length
-        columnType.actualSize(values(index)) + 4
+        columnType.actualSize(rows(index), 0) + 4
       }.sum
 
       // 4 extra bytes for compression scheme type ID
