Skip to content

Commit 9cf30b0

Browse files
committed
Added row based ColumnType.append/extract
1 parent 456c366 commit 9cf30b0

File tree

4 files changed

+87
-11
lines changed

4 files changed

+87
-11
lines changed

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](
5555
}
5656

5757
def extractSingle(row: MutableRow, ordinal: Int) {
58-
columnType.setField(row, ordinal, columnType.extract(buffer))
58+
columnType.extract(buffer, row, ordinal)
5959
}
6060

6161
protected def underlyingBuffer = buffer

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
7070

7171
override def appendFrom(row: Row, ordinal: Int) {
7272
buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal))
73-
columnType.append(columnType.getField(row, ordinal), buffer)
73+
columnType.append(row, ordinal, buffer)
7474
}
7575

7676
override def build() = {

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import java.sql.Timestamp
2323
import scala.reflect.runtime.universe.TypeTag
2424

2525
import org.apache.spark.sql.Row
26-
import org.apache.spark.sql.catalyst.expressions.{MutableAny, MutableRow, MutableValue}
26+
import org.apache.spark.sql.catalyst.expressions.MutableRow
2727
import org.apache.spark.sql.catalyst.types._
2828
import org.apache.spark.sql.execution.SparkSqlSerializer
2929

@@ -45,11 +45,28 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
4545
*/
4646
def extract(buffer: ByteBuffer): JvmType
4747

48+
/**
49+
* Extracts a value out of the buffer at the buffer's current position and stores in
50+
* `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever
51+
* possible.
52+
*/
53+
def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int) {
54+
setField(row, ordinal, extract(buffer))
55+
}
56+
4857
/**
4958
* Appends the given value v of type T into the given ByteBuffer.
5059
*/
5160
def append(v: JvmType, buffer: ByteBuffer)
5261

62+
/**
63+
* Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this
64+
* method to avoid boxing/unboxing costs whenever possible.
65+
*/
66+
def append(row: Row, ordinal: Int, buffer: ByteBuffer) {
67+
append(getField(row, ordinal), buffer)
68+
}
69+
5370
/**
5471
* Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable
5572
* length types such as byte arrays and strings.
@@ -101,10 +118,18 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
101118
buffer.putInt(v)
102119
}
103120

121+
override def append(row: Row, ordinal: Int, buffer: ByteBuffer) {
122+
buffer.putInt(row.getInt(ordinal))
123+
}
124+
104125
def extract(buffer: ByteBuffer) = {
105126
buffer.getInt()
106127
}
107128

129+
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int) {
130+
row.setInt(ordinal, buffer.getInt())
131+
}
132+
108133
override def setField(row: MutableRow, ordinal: Int, value: Int) {
109134
row.setInt(ordinal, value)
110135
}
@@ -121,10 +146,18 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
121146
buffer.putLong(v)
122147
}
123148

149+
override def append(row: Row, ordinal: Int, buffer: ByteBuffer) {
150+
buffer.putLong(row.getLong(ordinal))
151+
}
152+
124153
override def extract(buffer: ByteBuffer) = {
125154
buffer.getLong()
126155
}
127156

157+
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int) {
158+
row.setLong(ordinal, buffer.getLong())
159+
}
160+
128161
override def setField(row: MutableRow, ordinal: Int, value: Long) {
129162
row.setLong(ordinal, value)
130163
}
@@ -141,10 +174,18 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
141174
buffer.putFloat(v)
142175
}
143176

177+
override def append(row: Row, ordinal: Int, buffer: ByteBuffer) {
178+
buffer.putFloat(row.getFloat(ordinal))
179+
}
180+
144181
override def extract(buffer: ByteBuffer) = {
145182
buffer.getFloat()
146183
}
147184

185+
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int) {
186+
row.setFloat(ordinal, buffer.getFloat())
187+
}
188+
148189
override def setField(row: MutableRow, ordinal: Int, value: Float) {
149190
row.setFloat(ordinal, value)
150191
}
@@ -161,10 +202,18 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
161202
buffer.putDouble(v)
162203
}
163204

205+
override def append(row: Row, ordinal: Int, buffer: ByteBuffer) {
206+
buffer.putDouble(row.getDouble(ordinal))
207+
}
208+
164209
override def extract(buffer: ByteBuffer) = {
165210
buffer.getDouble()
166211
}
167212

213+
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int) {
214+
row.setDouble(ordinal, buffer.getDouble())
215+
}
216+
168217
override def setField(row: MutableRow, ordinal: Int, value: Double) {
169218
row.setDouble(ordinal, value)
170219
}
@@ -178,11 +227,19 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
178227

179228
private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
180229
override def append(v: Boolean, buffer: ByteBuffer) {
181-
buffer.put(if (v) 1.toByte else 0.toByte)
230+
buffer.put(if (v) 1: Byte else 0: Byte)
231+
}
232+
233+
override def append(row: Row, ordinal: Int, buffer: ByteBuffer) {
234+
buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte)
182235
}
183236

184237
override def extract(buffer: ByteBuffer) = buffer.get() == 1
185238

239+
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int) {
240+
row.setBoolean(ordinal, buffer.get() == 1)
241+
}
242+
186243
override def setField(row: MutableRow, ordinal: Int, value: Boolean) {
187244
row.setBoolean(ordinal, value)
188245
}
@@ -199,10 +256,18 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
199256
buffer.put(v)
200257
}
201258

259+
override def append(row: Row, ordinal: Int, buffer: ByteBuffer) {
260+
buffer.put(row.getByte(ordinal))
261+
}
262+
202263
override def extract(buffer: ByteBuffer) = {
203264
buffer.get()
204265
}
205266

267+
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int) {
268+
row.setByte(ordinal, buffer.get())
269+
}
270+
206271
override def setField(row: MutableRow, ordinal: Int, value: Byte) {
207272
row.setByte(ordinal, value)
208273
}
@@ -219,10 +284,18 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
219284
buffer.putShort(v)
220285
}
221286

287+
override def append(row: Row, ordinal: Int, buffer: ByteBuffer) {
288+
buffer.putShort(row.getShort(ordinal))
289+
}
290+
222291
override def extract(buffer: ByteBuffer) = {
223292
buffer.getShort()
224293
}
225294

295+
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int) {
296+
row.setShort(ordinal, buffer.getShort())
297+
}
298+
226299
override def setField(row: MutableRow, ordinal: Int, value: Short) {
227300
row.setShort(ordinal, value)
228301
}

sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ private[sql] case object PassThrough extends CompressionScheme {
5858
extends compression.Decoder[T] {
5959

6060
override def next(row: MutableRow, ordinal: Int) {
61-
columnType.setField(row, ordinal, columnType.extract(buffer))
61+
columnType.extract(buffer, row, ordinal)
6262
}
6363

6464
override def hasNext = buffer.hasRemaining
@@ -117,27 +117,30 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
117117
to.putInt(RunLengthEncoding.typeId)
118118

119119
if (from.hasRemaining) {
120-
var currentValue = columnType.extract(from)
120+
val currentValue = new SpecificMutableRow(Seq(columnType.dataType))
121121
var currentRun = 1
122+
val value = new SpecificMutableRow(Seq(columnType.dataType))
123+
124+
columnType.extract(from, currentValue, 0)
122125

123126
while (from.hasRemaining) {
124-
val value = columnType.extract(from)
127+
columnType.extract(from, value, 0)
125128

126-
if (value == currentValue) {
129+
if (value.head == currentValue.head) {
127130
currentRun += 1
128131
} else {
129132
// Writes current run
130-
columnType.append(currentValue, to)
133+
columnType.append(currentValue, 0, to)
131134
to.putInt(currentRun)
132135

133136
// Resets current run
134-
currentValue = value
137+
columnType.copyField(value, 0, currentValue, 0)
135138
currentRun = 1
136139
}
137140
}
138141

139142
// Writes the last run
140-
columnType.append(currentValue, to)
143+
columnType.append(currentValue, 0, to)
141144
to.putInt(currentRun)
142145
}
143146

0 commit comments

Comments
 (0)