
Commit 5be6b0e

[SPARK-6195] [SQL] Adds in-memory column type for fixed-precision decimals
This PR adds a specialized in-memory column type for fixed-precision decimals.

For all other column types, a single integer column type ID is enough to determine which column type to use. However, this doesn't apply to fixed-precision decimal types with different precision and scale parameters. Moreover, according to the previous design, there seems to be no trivial way to encode precision and scale information into the columnar byte buffer. On the other hand, we always know the data type of the column to be built or scanned ahead of time. So this PR no longer uses the column type ID to construct `ColumnBuilder`s and `ColumnAccessor`s, but resorts to the actual column data type instead. In this way, precision and scale information can be passed along the way. The column type ID is no longer used and can be removed in a future PR.

### Micro benchmark result

The following micro benchmark builds a simple table with 2 million decimals (precision = 10, scale = 0), caches it in memory, and then counts all the rows. Code (simply paste it into a Spark shell):

```scala
import sc._
import sqlContext._
import sqlContext.implicits._
import org.apache.spark.sql.types._
import com.google.common.base.Stopwatch

def benchmark(n: Int)(f: => Long) {
  val stopwatch = new Stopwatch()

  def run() = {
    stopwatch.reset()
    stopwatch.start()
    f
    stopwatch.stop()
    stopwatch.elapsedMillis()
  }

  val records = (0 until n).map(_ => run())

  (0 until n).foreach(i => println(s"Round $i: ${records(i)} ms"))
  println(s"Average: ${records.sum / n.toDouble} ms")
}

// Explicit casting is required because ScalaReflection can't inspect decimal precision
parallelize(1 to 2000000)
  .map(i => Tuple1(Decimal(i, 10, 0)))
  .toDF("dec")
  .select($"dec" cast DecimalType(10, 0))
  .registerTempTable("dec")

sql("CACHE TABLE dec")

val df = table("dec")

// Warm up
df.count()
df.count()

benchmark(5) {
  df.count()
}
```

With `FIXED_DECIMAL` column type:

- Round 0: 75 ms
- Round 1: 97 ms
- Round 2: 75 ms
- Round 3: 70 ms
- Round 4: 72 ms
- Average: 77.8 ms

Without `FIXED_DECIMAL` column type:

- Round 0: 1233 ms
- Round 1: 1170 ms
- Round 2: 1171 ms
- Round 3: 1141 ms
- Round 4: 1141 ms
- Average: 1171.2 ms

Author: Cheng Lian <[email protected]>

Closes #4938 from liancheng/decimal-column-type and squashes the following commits:

fef5338 [Cheng Lian] Updates fixed decimal column type related test cases
e08ab5b [Cheng Lian] Only resorts to FIXED_DECIMAL when the value can be held in a long
4db713d [Cheng Lian] Adds in-memory column type for fixed-precision decimals
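The compact representation works because any decimal with at most 18 digits of precision has an unscaled value below 10^18, which always fits in a signed 64-bit long; each cell can therefore be stored as a single 8-byte long and reconstructed from the column's statically known precision and scale. Below is a minimal, self-contained sketch of that round trip, using plain `java.math.BigDecimal` instead of Spark's internal `Decimal`; the object and helper names are hypothetical and only illustrate the encoding idea behind the new `FIXED_DECIMAL` column type.

```scala
import java.math.BigDecimal
import java.nio.ByteBuffer

object FixedDecimalSketch {
  // Write the decimal as its unscaled long value: 8 bytes, no precision/scale stored.
  def append(value: BigDecimal, buffer: ByteBuffer): Unit =
    buffer.putLong(value.unscaledValue().longValueExact())

  // Read the long back and reattach the scale known from the column's data type.
  def extract(buffer: ByteBuffer, scale: Int): BigDecimal =
    BigDecimal.valueOf(buffer.getLong(), scale)

  def main(args: Array[String]): Unit = {
    val buffer = ByteBuffer.allocate(8)
    append(new BigDecimal("12345.6789"), buffer) // precision = 9, scale = 4
    buffer.flip()
    println(extract(buffer, scale = 4))          // prints 12345.6789
  }
}
```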

File tree: 11 files changed, +179 −76 lines

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala

Lines changed: 25 additions & 18 deletions
```diff
@@ -21,7 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder}
 
 import org.apache.spark.sql.catalyst.expressions.MutableRow
 import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
-import org.apache.spark.sql.types.{BinaryType, DataType, NativeType}
+import org.apache.spark.sql.types._
 
 /**
  * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -89,6 +89,9 @@ private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
 private[sql] class FloatColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, FLOAT)
 
+private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
+  extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
+
 private[sql] class StringColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, STRING)
 
@@ -107,24 +110,28 @@ private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
   with NullableColumnAccessor
 
 private[sql] object ColumnAccessor {
-  def apply(buffer: ByteBuffer): ColumnAccessor = {
+  def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = {
     val dup = buffer.duplicate().order(ByteOrder.nativeOrder)
-    // The first 4 bytes in the buffer indicate the column type.
-    val columnTypeId = dup.getInt()
-
-    columnTypeId match {
-      case INT.typeId => new IntColumnAccessor(dup)
-      case LONG.typeId => new LongColumnAccessor(dup)
-      case FLOAT.typeId => new FloatColumnAccessor(dup)
-      case DOUBLE.typeId => new DoubleColumnAccessor(dup)
-      case BOOLEAN.typeId => new BooleanColumnAccessor(dup)
-      case BYTE.typeId => new ByteColumnAccessor(dup)
-      case SHORT.typeId => new ShortColumnAccessor(dup)
-      case STRING.typeId => new StringColumnAccessor(dup)
-      case DATE.typeId => new DateColumnAccessor(dup)
-      case TIMESTAMP.typeId => new TimestampColumnAccessor(dup)
-      case BINARY.typeId => new BinaryColumnAccessor(dup)
-      case GENERIC.typeId => new GenericColumnAccessor(dup)
+
+    // The first 4 bytes in the buffer indicate the column type. This field is not used now,
+    // because we always know the data type of the column ahead of time.
+    dup.getInt()
+
+    dataType match {
+      case IntegerType => new IntColumnAccessor(dup)
+      case LongType => new LongColumnAccessor(dup)
+      case FloatType => new FloatColumnAccessor(dup)
+      case DoubleType => new DoubleColumnAccessor(dup)
+      case BooleanType => new BooleanColumnAccessor(dup)
+      case ByteType => new ByteColumnAccessor(dup)
+      case ShortType => new ShortColumnAccessor(dup)
+      case StringType => new StringColumnAccessor(dup)
+      case BinaryType => new BinaryColumnAccessor(dup)
+      case DateType => new DateColumnAccessor(dup)
+      case TimestampType => new TimestampColumnAccessor(dup)
+      case DecimalType.Fixed(precision, scale) if precision < 19 =>
+        new FixedDecimalColumnAccessor(dup, precision, scale)
+      case _ => new GenericColumnAccessor(dup)
     }
   }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala

Lines changed: 23 additions & 16 deletions
```diff
@@ -106,6 +106,13 @@ private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleCol
 
 private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)
 
+private[sql] class FixedDecimalColumnBuilder(
+    precision: Int,
+    scale: Int)
+  extends NativeColumnBuilder(
+    new FixedDecimalColumnStats,
+    FIXED_DECIMAL(precision, scale))
+
 private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
 
 private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE)
@@ -139,25 +146,25 @@ private[sql] object ColumnBuilder {
   }
 
   def apply(
-      typeId: Int,
+      dataType: DataType,
      initialSize: Int = 0,
      columnName: String = "",
      useCompression: Boolean = false): ColumnBuilder = {
-
-    val builder = (typeId match {
-      case INT.typeId => new IntColumnBuilder
-      case LONG.typeId => new LongColumnBuilder
-      case FLOAT.typeId => new FloatColumnBuilder
-      case DOUBLE.typeId => new DoubleColumnBuilder
-      case BOOLEAN.typeId => new BooleanColumnBuilder
-      case BYTE.typeId => new ByteColumnBuilder
-      case SHORT.typeId => new ShortColumnBuilder
-      case STRING.typeId => new StringColumnBuilder
-      case BINARY.typeId => new BinaryColumnBuilder
-      case GENERIC.typeId => new GenericColumnBuilder
-      case DATE.typeId => new DateColumnBuilder
-      case TIMESTAMP.typeId => new TimestampColumnBuilder
-    }).asInstanceOf[ColumnBuilder]
+    val builder: ColumnBuilder = dataType match {
+      case IntegerType => new IntColumnBuilder
+      case LongType => new LongColumnBuilder
+      case DoubleType => new DoubleColumnBuilder
+      case BooleanType => new BooleanColumnBuilder
+      case ByteType => new ByteColumnBuilder
+      case ShortType => new ShortColumnBuilder
+      case StringType => new StringColumnBuilder
+      case BinaryType => new BinaryColumnBuilder
+      case DateType => new DateColumnBuilder
+      case TimestampType => new TimestampColumnBuilder
+      case DecimalType.Fixed(precision, scale) if precision < 19 =>
+        new FixedDecimalColumnBuilder(precision, scale)
+      case _ => new GenericColumnBuilder
+    }
 
     builder.initialize(initialSize, columnName, useCompression)
     builder
```

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala

Lines changed: 17 additions & 0 deletions
```diff
@@ -181,6 +181,23 @@ private[sql] class FloatColumnStats extends ColumnStats {
   def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
+private[sql] class FixedDecimalColumnStats extends ColumnStats {
+  protected var upper: Decimal = null
+  protected var lower: Decimal = null
+
+  override def gatherStats(row: Row, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row(ordinal).asInstanceOf[Decimal]
+      if (upper == null || value.compareTo(upper) > 0) upper = value
+      if (lower == null || value.compareTo(lower) < 0) lower = value
+      sizeInBytes += FIXED_DECIMAL.defaultSize
+    }
+  }
+
+  override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+}
+
 private[sql] class IntColumnStats extends ColumnStats {
   protected var upper = Int.MinValue
   protected var lower = Int.MaxValue
```

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala

Lines changed: 42 additions & 13 deletions
```diff
@@ -373,6 +373,33 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) {
   }
 }
 
+private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
+  extends NativeColumnType(
+    DecimalType(Some(PrecisionInfo(precision, scale))),
+    10,
+    FIXED_DECIMAL.defaultSize) {
+
+  override def extract(buffer: ByteBuffer): Decimal = {
+    Decimal(buffer.getLong(), precision, scale)
+  }
+
+  override def append(v: Decimal, buffer: ByteBuffer): Unit = {
+    buffer.putLong(v.toUnscaledLong)
+  }
+
+  override def getField(row: Row, ordinal: Int): Decimal = {
+    row(ordinal).asInstanceOf[Decimal]
+  }
+
+  override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
+    row(ordinal) = value
+  }
+}
+
+private[sql] object FIXED_DECIMAL {
+  val defaultSize = 8
+}
+
 private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
     typeId: Int,
     defaultSize: Int)
@@ -394,7 +421,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
   }
 }
 
-private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16) {
+private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) {
   override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
     row(ordinal) = value
   }
@@ -405,7 +432,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16)
 // Used to process generic objects (all types other than those listed above). Objects should be
 // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
 // byte array.
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) {
+private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) {
   override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
     row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
   }
@@ -416,18 +443,20 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) {
 private[sql] object ColumnType {
   def apply(dataType: DataType): ColumnType[_, _] = {
     dataType match {
-      case IntegerType => INT
-      case LongType => LONG
-      case FloatType => FLOAT
-      case DoubleType => DOUBLE
-      case BooleanType => BOOLEAN
-      case ByteType => BYTE
-      case ShortType => SHORT
-      case StringType => STRING
-      case BinaryType => BINARY
-      case DateType => DATE
+      case IntegerType => INT
+      case LongType => LONG
+      case FloatType => FLOAT
+      case DoubleType => DOUBLE
+      case BooleanType => BOOLEAN
+      case ByteType => BYTE
+      case ShortType => SHORT
+      case StringType => STRING
+      case BinaryType => BINARY
+      case DateType => DATE
       case TimestampType => TIMESTAMP
-      case _ => GENERIC
+      case DecimalType.Fixed(precision, scale) if precision < 19 =>
+        FIXED_DECIMAL(precision, scale)
+      case _ => GENERIC
     }
   }
 }
```
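For completeness, the `precision < 19` guard in `ColumnType.apply` above is just a capacity check: `Long.MaxValue` (9223372036854775807) has 19 digits, so an 18-digit unscaled value always fits in a long while a 19-digit one may overflow, and such columns fall back to `GENERIC`. A small stand-alone sketch of that selection rule, where `chooseColumnType` and the two case classes are illustrative stand-ins rather than Spark APIs:

```scala
// Illustrative stand-ins for the column types chosen by ColumnType.apply.
sealed trait ChosenColumnType
case class FixedDecimalColumn(precision: Int, scale: Int) extends ChosenColumnType
case object GenericColumn extends ChosenColumnType

// Up to 18 digits of precision the unscaled value fits in a Long, so the
// compact 8-byte representation is safe; beyond that, fall back to GENERIC.
def chooseColumnType(precision: Int, scale: Int): ChosenColumnType =
  if (precision < 19) FixedDecimalColumn(precision, scale) else GenericColumn

assert(chooseColumnType(18, 2) == FixedDecimalColumn(18, 2))
assert(chooseColumnType(19, 0) == GenericColumn)
```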

sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala

Lines changed: 5 additions & 3 deletions
```diff
@@ -113,7 +113,7 @@ private[sql] case class InMemoryRelation(
     val columnBuilders = output.map { attribute =>
       val columnType = ColumnType(attribute.dataType)
       val initialBufferSize = columnType.defaultSize * batchSize
-      ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression)
+      ColumnBuilder(attribute.dataType, initialBufferSize, attribute.name, useCompression)
     }.toArray
 
     var rowCount = 0
@@ -274,8 +274,10 @@ private[sql] case class InMemoryColumnarTableScan(
     def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = {
       val rows = cacheBatches.flatMap { cachedBatch =>
         // Build column accessors
-        val columnAccessors = requestedColumnIndices.map { batch =>
-          ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch)))
+        val columnAccessors = requestedColumnIndices.map { batchColumnIndex =>
+          ColumnAccessor(
+            relation.output(batchColumnIndex).dataType,
+            ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex)))
         }
 
         // Extract rows via column accessors
```

sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -29,6 +29,7 @@ class ColumnStatsSuite extends FunSuite {
   testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0))
   testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0))
   testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0))
+  testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), Row(null, null, 0))
   testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0))
   testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0))
   testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0))
```

sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala

Lines changed: 34 additions & 12 deletions
```diff
@@ -33,8 +33,9 @@ class ColumnTypeSuite extends FunSuite with Logging {
 
   test("defaultSize") {
     val checks = Map(
-      INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, BOOLEAN -> 1,
-      STRING -> 8, DATE -> 4, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16)
+      INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
+      FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 12,
+      BINARY -> 16, GENERIC -> 16)
 
     checks.foreach { case (columnType, expectedSize) =>
       assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
@@ -56,15 +57,16 @@
       }
     }
 
-    checkActualSize(INT, Int.MaxValue, 4)
-    checkActualSize(SHORT, Short.MaxValue, 2)
-    checkActualSize(LONG, Long.MaxValue, 8)
-    checkActualSize(BYTE, Byte.MaxValue, 1)
-    checkActualSize(DOUBLE, Double.MaxValue, 8)
-    checkActualSize(FLOAT, Float.MaxValue, 4)
-    checkActualSize(BOOLEAN, true, 1)
-    checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
-    checkActualSize(DATE, 0, 4)
+    checkActualSize(INT, Int.MaxValue, 4)
+    checkActualSize(SHORT, Short.MaxValue, 2)
+    checkActualSize(LONG, Long.MaxValue, 8)
+    checkActualSize(BYTE, Byte.MaxValue, 1)
+    checkActualSize(DOUBLE, Double.MaxValue, 8)
+    checkActualSize(FLOAT, Float.MaxValue, 4)
+    checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
+    checkActualSize(BOOLEAN, true, 1)
+    checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
+    checkActualSize(DATE, 0, 4)
     checkActualSize(TIMESTAMP, new Timestamp(0L), 12)
 
     val binary = Array.fill[Byte](4)(0: Byte)
@@ -93,12 +95,20 @@
 
   testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble)
 
+  testNativeColumnType[DecimalType](
+    FIXED_DECIMAL(15, 10),
+    (buffer: ByteBuffer, decimal: Decimal) => {
+      buffer.putLong(decimal.toUnscaledLong)
+    },
+    (buffer: ByteBuffer) => {
+      Decimal(buffer.getLong(), 15, 10)
+    })
+
   testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat)
 
   testNativeColumnType[StringType.type](
     STRING,
     (buffer: ByteBuffer, string: String) => {
-
       val bytes = string.getBytes("utf-8")
       buffer.putInt(bytes.length)
       buffer.put(bytes)
@@ -206,4 +216,16 @@
     if (sb.nonEmpty) sb.setLength(sb.length - 1)
     sb.toString()
   }
+
+  test("column type for decimal types with different precision") {
+    (1 to 18).foreach { i =>
+      assertResult(FIXED_DECIMAL(i, 0)) {
+        ColumnType(DecimalType(i, 0))
+      }
+    }
+
+    assertResult(GENERIC) {
+      ColumnType(DecimalType(19, 0))
+    }
+  }
 }
```

sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala

Lines changed: 12 additions & 11 deletions
```diff
@@ -24,7 +24,7 @@ import scala.util.Random
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.{DataType, NativeType}
+import org.apache.spark.sql.types.{Decimal, DataType, NativeType}
 
 object ColumnarTestUtils {
   def makeNullRow(length: Int) = {
@@ -41,16 +41,17 @@
   }
 
     (columnType match {
-      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
-      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
-      case INT => Random.nextInt()
-      case LONG => Random.nextLong()
-      case FLOAT => Random.nextFloat()
-      case DOUBLE => Random.nextDouble()
-      case STRING => Random.nextString(Random.nextInt(32))
-      case BOOLEAN => Random.nextBoolean()
-      case BINARY => randomBytes(Random.nextInt(32))
-      case DATE => Random.nextInt()
+      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
+      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
+      case INT => Random.nextInt()
+      case LONG => Random.nextLong()
+      case FLOAT => Random.nextFloat()
+      case DOUBLE => Random.nextDouble()
+      case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
+      case STRING => Random.nextString(Random.nextInt(32))
+      case BOOLEAN => Random.nextBoolean()
+      case BINARY => randomBytes(Random.nextInt(32))
+      case DATE => Random.nextInt()
       case TIMESTAMP =>
         val timestamp = new Timestamp(Random.nextLong())
         timestamp.setNanos(Random.nextInt(999999999))
```
