
Commit 97bbc4e

Optimizes hive.TableReader by providing specific Writable unwrappers ahead of time
1 parent 3dc1f94 commit 97bbc4e

File tree

2 files changed: +75 -50 lines changed

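The diffs below are easiest to read with the commit message in mind: instead of pattern matching on each field's object inspector for every row, the reader now decides once, per field, how to unwrap a value into a reused mutable row, and the per-row loop just applies those prebuilt functions. The following is a minimal, self-contained Scala sketch of that pattern only; FieldType and SimpleMutableRow are hypothetical stand-ins, not Spark's or Hive's actual classes.

object PrebuiltUnwrappersSketch {
  // Hypothetical field descriptors standing in for Hive ObjectInspectors.
  sealed trait FieldType
  case object IntField extends FieldType
  case object DoubleField extends FieldType
  case object StringField extends FieldType

  // A tiny stand-in for Spark's MutableRow, with type-specific setters.
  final class SimpleMutableRow(size: Int) {
    private val values = new Array[Any](size)
    def setInt(i: Int, v: Int): Unit = values(i) = v
    def setDouble(i: Int, v: Double): Unit = values(i) = v
    def update(i: Int, v: Any): Unit = values(i) = v
    override def toString: String = values.mkString("[", ", ", "]")
  }

  def main(args: Array[String]): Unit = {
    val schema: Seq[FieldType] = Seq(IntField, DoubleField, StringField)

    // Built once: one unwrapper per field, so the hot loop below never
    // pattern matches on the field type again.
    val unwrappers: Array[(Any, SimpleMutableRow, Int) => Unit] = schema.map {
      case IntField =>
        (v: Any, row: SimpleMutableRow, i: Int) => row.setInt(i, v.asInstanceOf[Int])
      case DoubleField =>
        (v: Any, row: SimpleMutableRow, i: Int) => row.setDouble(i, v.asInstanceOf[Double])
      case _ =>
        (v: Any, row: SimpleMutableRow, i: Int) => row(i) = v
    }.toArray

    val rawRecords = Iterator(Array[Any](1, 2.5, "a"), Array[Any](2, 3.5, "b"))
    val reusedRow = new SimpleMutableRow(schema.length)

    rawRecords.foreach { raw =>
      var i = 0
      while (i < unwrappers.length) {   // tight loop: no per-row type dispatch
        unwrappers(i)(raw(i), reusedRow, i)
        i += 1
      }
      println(reusedRow)
    }
  }
}

The same shape appears in the new fillObject in TableReader.scala below, where the stand-ins are replaced by Hive ObjectInspectors and Spark's MutableRow.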

sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala

Lines changed: 3 additions & 3 deletions
@@ -189,7 +189,7 @@ private[sql] case class InMemoryColumnarTableScan(
     readPartitions.setValue(0)
     readBatches.setValue(0)

-    relation.cachedColumnBuffers.mapPartitions { iterator =>
+    relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator =>
       val partitionFilter = newPredicate(
         partitionFilters.reduceOption(And).getOrElse(Literal(true)),
         relation.partitionStatistics.schema)
@@ -211,7 +211,7 @@ private[sql] case class InMemoryColumnarTableScan(
       }

       val nextRow = new SpecificMutableRow(requestedColumnDataTypes)
-      val rows = iterator
+      val rows = cachedBatchIterator
         // Skip pruned batches
         .filter { cachedBatch =>
           if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) {
@@ -242,7 +242,7 @@ private[sql] case class InMemoryColumnarTableScan(
           nextRow
         }

-        override def hasNext = columnAccessors.head.hasNext
+        override def hasNext = columnAccessors(0).hasNext
       }
     }

sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala

Lines changed: 72 additions & 47 deletions
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.hive

+import scala.collection.JavaConversions._
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{Path, PathFilter}
 import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._
@@ -25,16 +27,14 @@ import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table =>
 import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc}
 import org.apache.hadoop.hive.serde2.Deserializer
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
-
+import org.apache.hadoop.hive.serde2.objectinspector.primitive._
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf}

 import org.apache.spark.SerializableWritable
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
-
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast}
-import org.apache.spark.sql.catalyst.types.DataType
+import org.apache.spark.sql.catalyst.expressions._

 /**
  * A trait for subclasses that handle table scans.
@@ -108,12 +108,12 @@ class HadoopTableReader(
     val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)

     val attrsWithIndex = attributes.zipWithIndex
-    val mutableRow = new GenericMutableRow(attrsWithIndex.length)
+    val mutableRow = new SpecificMutableRow(attributes.map(_.dataType))
+
     val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter =>
       val hconf = broadcastedHiveConf.value.value
       val deserializer = deserializerClass.newInstance()
       deserializer.initialize(hconf, tableDesc.getProperties)
-
       HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow)
     }

@@ -164,33 +164,32 @@ class HadoopTableReader(
       val tableDesc = relation.tableDesc
       val broadcastedHiveConf = _broadcastedHiveConf
       val localDeserializer = partDeserializer
-      val mutableRow = new GenericMutableRow(attributes.length)
-
-      // split the attributes (output schema) into 2 categories:
-      // (partition keys, ordinal), (normal attributes, ordinal), the ordinal mean the
-      // index of the attribute in the output Row.
-      val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => {
-        relation.partitionKeys.indexOf(attr._1) >= 0
-      })
-
-      def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = {
-        partitionKeys.foreach { case (attr, ordinal) =>
-          // get partition key ordinal for a given attribute
-          val partOridinal = relation.partitionKeys.indexOf(attr)
-          row(ordinal) = Cast(Literal(parts(partOridinal)), attr.dataType).eval(null)
+      val mutableRow = new SpecificMutableRow(attributes.map(_.dataType))
+
+      // Splits all attributes into two groups, partition key attributes and those that are not.
+      // Attached indices indicate the position of each attribute in the output schema.
+      val (partitionKeyAttrs, nonPartitionKeyAttrs) =
+        attributes.zipWithIndex.partition { case (attr, _) =>
+          relation.partitionKeys.contains(attr)
+        }
+
+      def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow) = {
+        partitionKeyAttrs.foreach { case (attr, ordinal) =>
+          val partOrdinal = relation.partitionKeys.indexOf(attr)
+          row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null)
         }
       }
-      // fill the partition key for the given MutableRow Object
+
+      // Fill all partition keys to the given MutableRow object
      fillPartitionKeys(partValues, mutableRow)

-      val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
-      hivePartitionRDD.mapPartitions { iter =>
+      createHadoopRdd(tableDesc, inputPathStr, ifc).mapPartitions { iter =>
         val hconf = broadcastedHiveConf.value.value
         val deserializer = localDeserializer.newInstance()
         deserializer.initialize(hconf, partProps)

-        // fill the non partition key attributes
-        HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow)
+        // fill the non partition key attributes
+        HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow)
       }
     }.toSeq

@@ -257,38 +256,64 @@ private[hive] object HadoopTableReader extends HiveInspectors {
   }

   /**
-   * Transform the raw data(Writable object) into the Row object for an iterable input
-   * @param iter Iterable input which represented as Writable object
-   * @param deserializer Deserializer associated with the input writable object
-   * @param attrs Represents the row attribute names and its zero-based position in the MutableRow
-   * @param row reusable MutableRow object
-   *
-   * @return Iterable Row object that transformed from the given iterable input.
+   * Transform all given raw `Writable`s into `Row`s.
+   *
+   * @param iterator Iterator of all `Writable`s to be transformed
+   * @param deserializer The `Deserializer` associated with the input `Writable`
+   * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding
+   *                             positions in the output schema
+   * @param mutableRow A reusable `MutableRow` that should be filled
+   * @return An `Iterator[Row]` transformed from `iterator`
    */
   def fillObject(
-      iter: Iterator[Writable],
+      iterator: Iterator[Writable],
       deserializer: Deserializer,
-      attrs: Seq[(Attribute, Int)],
-      row: GenericMutableRow): Iterator[Row] = {
+      nonPartitionKeyAttrs: Seq[(Attribute, Int)],
+      mutableRow: MutableRow): Iterator[Row] = {
+
     val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector]
-    // get the field references according to the attributes(output of the reader) required
-    val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) }
+    val fieldRefsWithOrdinals = {
+      val allFieldRefs = soi.getAllStructFieldRefs
+      nonPartitionKeyAttrs.map { case (_, ordinal) => allFieldRefs(ordinal) -> ordinal }
+    }
+
+    // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern
+    // matching and branching costs per row.
+    val unwrappers: Seq[(Any, MutableRow, Int) => Unit] =
+      soi.getAllStructFieldRefs.map {
+        _.getFieldObjectInspector match {
+          case oi: BooleanObjectInspector =>
+            (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value))
+          case oi: ByteObjectInspector =>
+            (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value))
+          case oi: ShortObjectInspector =>
+            (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value))
+          case oi: IntObjectInspector =>
+            (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value))
+          case oi: LongObjectInspector =>
+            (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value))
+          case oi: FloatObjectInspector =>
+            (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value))
+          case oi: DoubleObjectInspector =>
+            (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value))
+          case oi =>
+            (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapData(value, oi)
        }
      }.toSeq

     // Map each tuple to a row object
-    iter.map { value =>
+    iterator.map { value =>
       val raw = deserializer.deserialize(value)
-      var idx = 0;
-      while (idx < fieldRefs.length) {
-        val fieldRef = fieldRefs(idx)._1
-        val fieldIdx = fieldRefs(idx)._2
+      var i = 0
+      while (i < fieldRefsWithOrdinals.length) {
+        val fieldRef = fieldRefsWithOrdinals(i)._1
+        val fieldOrdinal = fieldRefsWithOrdinals(i)._2
         val fieldValue = soi.getStructFieldData(raw, fieldRef)
-
-        row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector())
-
-        idx += 1
+        unwrappers(i)(fieldValue, mutableRow, fieldOrdinal)
+        i += 1
       }

-      row: Row
+      mutableRow: Row
     }
   }
 }
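A side note on where the unwrapper construction runs: everything in fillObject before the returned iterator is built (including the unwrappers sequence) executes once per mapPartitions call, i.e. once per partition, while only the body of iterator.map runs per record. A small hedged sketch of that shape, using a hypothetical fillLike helper rather than the real method:

object PerPartitionSetupSketch {
  // Hypothetical fillObject-like helper: work done before `records.map` runs once per
  // call (i.e. once per partition), while the mapped function runs once per record.
  def fillLike[A, B](records: Iterator[A], setup: () => A => B): Iterator[B] = {
    val convert = setup()   // per-partition cost, e.g. building unwrappers
    records.map(convert)    // per-record cost: just apply the prebuilt function
  }

  def main(args: Array[String]): Unit = {
    val out = fillLike(Iterator(1, 2, 3), () => {
      println("setup runs once")
      (i: Int) => i * 2
    })
    println(out.toList)   // prints "setup runs once" a single time, then List(2, 4, 6)
  }
}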

0 commit comments
