@@ -17,16 +17,29 @@

package org.apache.spark.sql.execution.datasources.orc

import java.io.File

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.lib.input.FileSplit
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.orc.TypeDescription

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase
import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.UTF8String.fromString

class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {

  import testImplicits._

  private val dataSchema = StructType.fromDDL("col1 int, col2 int")
  private val partitionSchema = StructType.fromDDL("p1 string, p2 string")
  private val partitionValues = InternalRow(fromString("partValue1"), fromString("partValue2"))
@@ -77,4 +90,66 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {
      assert(p1.getUTF8String(0) === partitionValues.getUTF8String(0))
    }
  }

test("SPARK-33593: OrcColumnarBatchReader - partition column types") {
withTempPath { dir =>
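      // Write a single-row ORC file to read back. Partition columns are never stored in the
      // data file; the reader injects them from `partitionValues` below.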
      Seq(1).toDF().repartition(1).write.orc(dir.getCanonicalPath)

      val dataTypes =
        Seq(StringType, BooleanType, ByteType, BinaryType, ShortType, IntegerType, LongType,
          FloatType, DoubleType, DecimalType(25, 5), DateType, TimestampType)

      val constantValues =
        Seq(
          UTF8String.fromString("a string"),
          true,
          1.toByte,
          "Spark SQL".getBytes,
          2.toShort,
          3,
          Long.MaxValue,
          0.25.toFloat,
          0.75D,
          Decimal("1234.23456"),
          DateTimeUtils.fromJavaDate(java.sql.Date.valueOf("2015-01-01")),
          DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123")))

      dataTypes.zip(constantValues).foreach { case (dt, v) =>
        val schema = StructType(StructField("col1", IntegerType) :: StructField("pcol", dt) :: Nil)
        val partitionValues = new GenericInternalRow(Array(v))
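        // `SpecificParquetRecordReaderBase.listDirectory` is reused here only as a helper to
        // list the data files under the temporary directory; it is not Parquet-specific.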
        val file = new File(SpecificParquetRecordReaderBase.listDirectory(dir).get(0))
        val fileSplit = new FileSplit(new Path(file.getCanonicalPath), 0L, file.length, Array.empty)
        val taskConf = sqlContext.sessionState.newHadoopConf()
        val orcFileSchema = TypeDescription.fromString(schema.simpleString)
        val vectorizedReader = new OrcColumnarBatchReader(4096)
        val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
        val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)

        try {
          vectorizedReader.initialize(fileSplit, taskAttemptContext)
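          // Requested column ids: `col1` comes from ORC data column 0 (hence -1 in the
          // partition array), while `pcol` is absent from the file (-1 in the data array)
          // and is filled from `partitionValues` index 0.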
          vectorizedReader.initBatch(
            orcFileSchema,
            schema.toArray,
            Array(0, -1),
            Array(-1, 0),
            partitionValues)
          vectorizedReader.nextKeyValue()
          val row = vectorizedReader.getCurrentValue.getRow(0)

          // Use a `GenericInternalRow` obtained by explicitly copying rather than `ColumnarBatch`,
          // in order to use the get(...) method, which is not implemented in `ColumnarBatch`.
          val actual = row.copy().get(1, dt)
          val expected = v
          if (dt.isInstanceOf[BinaryType]) {
            assert(actual.asInstanceOf[Array[Byte]]
              sameElements expected.asInstanceOf[Array[Byte]])
          } else {
            assert(actual == expected)
          }
        } finally {
          vectorizedReader.close()
        }
      }
    }
  }
}