diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala new file mode 100644 index 000000000000..713c076aee45 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + + +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ + +import org.apache.spark.sql.catalyst.expressions.MutableRow +import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} + +/** + * We can consolidate TableReader.unwrappers and HiveInspectors.wrapperFor to use + * this class. + * + */ +private[hive] object HadoopTypeConverter extends HiveInspectors { + /** + * Builds specific unwrappers ahead of time according to object inspector + * types to avoid pattern matching and branching costs per row. + */ + def unwrappers(fieldRefs: Seq[StructField]): Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { + _.getFieldObjectInspector match { + case oi: BooleanObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + case oi: ByteObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + case oi: ShortObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + case oi: IntObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + case oi: LongObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + case oi: FloatObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + case oi: DoubleObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi => + (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) + } + } + + /** + * Wraps with Hive types based on object inspector. + */ + def wrappers(oi: ObjectInspector): Any => Any = wrapperFor(oi) +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala new file mode 100644 index 000000000000..4dd2d8951b72 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
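[Editor's note] HadoopTypeConverter.unwrappers builds one converter closure per field ahead of time so the per-row loop never pattern-matches on object inspectors. Below is a minimal standalone sketch of that build-once/apply-per-row idea, using plain Scala classes in place of Hive ObjectInspectors and an Array[Any] in place of MutableRow; every name in it is hypothetical and only illustrates the technique.

  object UnwrapperSketch {
    // One setter per column, chosen once; the per-row loop just invokes them.
    type Setter = (Any, Array[Any], Int) => Unit

    def setters(columnTypes: Seq[Class[_]]): Seq[Setter] = columnTypes.map {
      case c if c == classOf[java.lang.Integer] =>
        (v: Any, row: Array[Any], i: Int) => row(i) = v.asInstanceOf[java.lang.Integer].intValue()
      case c if c == classOf[java.lang.Boolean] =>
        (v: Any, row: Array[Any], i: Int) => row(i) = v.asInstanceOf[java.lang.Boolean].booleanValue()
      case _ =>
        (v: Any, row: Array[Any], i: Int) => row(i) = v  // fallback, mirrors the final `case oi` branch
    }

    def main(args: Array[String]): Unit = {
      val convert = setters(Seq(classOf[java.lang.Integer], classOf[String]))
      val row = new Array[Any](2)
      val raw: Seq[Any] = Seq(Int.box(42), "hello")
      var i = 0
      while (i < raw.length) { convert(i)(raw(i), row, i); i += 1 }
      println(row.toSeq)  // List(42, hello)
    }
  }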
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.hive.HiveMetastoreTypes +import org.apache.spark.sql.types.StructType + +private[orc] object OrcFileOperator extends Logging{ + + def getFileReader(pathStr: String, config: Option[Configuration] = None ): Reader = { + val conf = config.getOrElse(new Configuration) + val fspath = new Path(pathStr) + val fs = fspath.getFileSystem(conf) + val orcFiles = listOrcFiles(pathStr, conf) + OrcFile.createReader(fs, orcFiles(0)) + } + + def readSchema(path: String, conf: Option[Configuration]): StructType = { + val reader = getFileReader(path, conf) + val readerInspector: StructObjectInspector = reader.getObjectInspector + .asInstanceOf[StructObjectInspector] + val schema = readerInspector.getTypeName + HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] + } + + def getObjectInspector(path: String, conf: Option[Configuration]): StructObjectInspector = { + val reader = getFileReader(path, conf) + val readerInspector: StructObjectInspector = reader.getObjectInspector + .asInstanceOf[StructObjectInspector] + readerInspector + } + + def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { + val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + val path = origPath.makeQualified(fs) + val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) + .filterNot(_.isDir) + .map(_.getPath) + .filterNot(_.getName.startsWith("_")) + .filterNot(_.getName.startsWith(".")) + + if (paths == null || paths.size == 0) { + throw new IllegalArgumentException( + s"orcFileOperator: path $path does not have valid orc files matching the pattern") + } + paths + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala new file mode 100644 index 000000000000..eda1cffe4981 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
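[Editor's note] OrcFileOperator infers a relation's schema by opening the first ORC file under a path and converting its ObjectInspector type string through HiveMetastoreTypes. A hedged usage sketch (the path is made up, and a Hadoop Configuration is assumed to be available):

  import org.apache.hadoop.conf.Configuration

  val conf = new Configuration()
  // Reads the footer of the first ORC file under the directory and converts its
  // type string (e.g. struct<key:int,value:string>) into a Spark StructType.
  val schema = OrcFileOperator.readSchema("/tmp/orc-data", Some(conf))
  println(schema)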
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder +import org.apache.spark.Logging +import org.apache.spark.sql.sources._ + +/** + * It may be optimized by push down partial filters. But we are conservative here. + * Because if some filters fail to be parsed, the tree may be corrupted, + * and cannot be used anymore. + */ +private[orc] object OrcFilters extends Logging { + def createFilter(expr: Array[Filter]): Option[SearchArgument] = { + if (expr.nonEmpty) { + expr.foldLeft(Some(SearchArgument.FACTORY.newBuilder().startAnd()): Option[Builder]) { + (maybeBuilder, e) => createFilter(e, maybeBuilder) + }.map(_.end().build()) + } else { + None + } + } + + private def createFilter(expression: Filter, maybeBuilder: Option[Builder]): Option[Builder] = { + maybeBuilder.flatMap { builder => + expression match { + case p@And(left, right) => + for { + lhs <- createFilter(left, Some(builder.startAnd())) + rhs <- createFilter(right, Some(lhs)) + } yield rhs.end() + case p@Or(left, right) => + for { + lhs <- createFilter(left, Some(builder.startOr())) + rhs <- createFilter(right, Some(lhs)) + } yield rhs.end() + case p@Not(child) => + createFilter(child, Some(builder.startNot())).map(_.end()) + case p@EqualTo(attribute, value) => + Some(builder.equals(attribute, value)) + case p@LessThan(attribute, value) => + Some(builder.lessThan(attribute, value)) + case p@LessThanOrEqual(attribute, value) => + Some(builder.lessThanEquals(attribute, value)) + case p@GreaterThan(attribute, value) => + Some(builder.startNot().lessThanEquals(attribute, value).end()) + case p@GreaterThanOrEqual(attribute, value) => + Some(builder.startNot().lessThan(attribute, value).end()) + case p@IsNull(attribute) => + Some(builder.isNull(attribute)) + case p@IsNotNull(attribute) => + Some(builder.startNot().isNull(attribute).end()) + case p@In(attribute, values) => + Some(builder.in(attribute, values)) + case _ => None + } + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala new file mode 100644 index 000000000000..44ac728b09aa --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
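[Editor's note] OrcFilters folds the whole Array[Filter] into a single AND-ed SearchArgument and returns None as soon as any predicate cannot be translated, which is why the conversion is all-or-nothing. A hypothetical call from within the orc package, with made-up column names:

  import org.apache.spark.sql.sources.{EqualTo, GreaterThan}

  // GreaterThan has no direct builder call, so it is encoded as NOT(lessThanEquals),
  // matching the `case p@GreaterThan(...)` branch above:
  //   leaf-0 = (LESS_THAN_EQUALS age 5), leaf-1 = (EQUALS name foo)
  //   expr   = (and (not leaf-0) leaf-1)
  val sarg = OrcFilters.createFilter(Array(GreaterThan("age", 5), EqualTo("name", "foo")))
  sarg.foreach(println)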
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.util.Objects + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.io.orc.{OrcSerde, OrcOutputFormat} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StructObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfoUtils, TypeInfo} +import org.apache.hadoop.io.{Writable, NullWritable} +import org.apache.hadoop.mapred.{RecordWriter, Reporter, JobConf} +import org.apache.hadoop.mapreduce.{TaskID, TaskAttemptContext} + +import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.hive.HiveMetastoreTypes +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.sources._ + +/* Implicit conversions */ +import scala.collection.JavaConversions._ + + +private[sql] class DefaultSource extends FSBasedRelationProvider { + def createRelation( + sqlContext: SQLContext, + paths: Array[String], + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): FSBasedRelation ={ + val partitionSpec = partitionColumns.map(PartitionSpec(_, Seq.empty[Partition])) + OrcRelation(paths, parameters, + schema, partitionSpec)(sqlContext) + } +} + + +private[sql] class OrcOutputWriter extends OutputWriter with SparkHadoopMapRedUtil { + + var taskAttemptContext: TaskAttemptContext = _ + var serializer: OrcSerde = _ + var wrappers: Array[Any => Any] = _ + var created = false + var path: String = _ + var dataSchema: StructType = _ + var fieldOIs: Array[ObjectInspector] = _ + var structOI: StructObjectInspector = _ + var outputData: Array[Any] = _ + lazy val recordWriter: RecordWriter[NullWritable, Writable] = { + created = true + val conf = taskAttemptContext.getConfiguration + val taskId: TaskID = taskAttemptContext.getTaskAttemptID.getTaskID + val partition: Int = taskId.getId + val filename = f"part-r-$partition%05d-${System.currentTimeMillis}%015d.orc" + val file = new Path(path, filename) + val fs = file.getFileSystem(conf) + val outputFormat = new OrcOutputFormat() + outputFormat.getRecordWriter(fs, + conf.asInstanceOf[JobConf], + file.toUri.getPath, Reporter.NULL) + .asInstanceOf[org.apache.hadoop.mapred.RecordWriter[NullWritable, Writable]] + } + + override def init(path: String, + dataSchema: StructType, + context: TaskAttemptContext): Unit = { + this.path = path + taskAttemptContext = context + val orcSchema = HiveMetastoreTypes.toMetastoreType(dataSchema) + serializer = new OrcSerde + val typeInfo: TypeInfo = + TypeInfoUtils.getTypeInfoFromTypeString(orcSchema) + structOI = TypeInfoUtils + .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) + .asInstanceOf[StructObjectInspector] + fieldOIs = structOI + .getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + outputData = new Array[Any](fieldOIs.length) + wrappers = fieldOIs.map(HadoopTypeConverter.wrappers) + } + + override def write(row: Row): Unit = { + var i = 0 + while (i < row.length) { + outputData(i) = wrappers(i)(row(i)) + i += 1 + } + val writable = serializer.serialize(outputData, structOI) + recordWriter.write(NullWritable.get(), writable) + } + + override def close(): Unit = { + if (created) { + recordWriter.close(Reporter.NULL) + } + } +} + +@DeveloperApi +private[sql] case class 
OrcRelation( + override val paths: Array[String], + parameters: Map[String, String], + maybeSchema: Option[StructType] = None, + maybePartitionSpec: Option[PartitionSpec] = None)( + @transient val sqlContext: SQLContext) + extends FSBasedRelation(paths, maybePartitionSpec) + with Logging { + override val dataSchema: StructType = + maybeSchema.getOrElse(OrcFileOperator.readSchema(paths(0), + Some(sqlContext.sparkContext.hadoopConfiguration))) + + override def outputWriterClass: Class[_ <: OutputWriter] = classOf[OrcOutputWriter] + + override def needConversion: Boolean = false + + override def equals(other: Any): Boolean = other match { + case that: OrcRelation => + paths.toSet == that.paths.toSet && + dataSchema == that.dataSchema && + schema == that.schema && + partitionColumns == that.partitionColumns + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode( + paths.toSet, + dataSchema, + schema, + maybePartitionSpec) + } + + override def buildScan(requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[String]): RDD[Row] = { + val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes + OrcTableScan(output, this, filters, inputPaths).execute() + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala new file mode 100644 index 000000000000..2163b0ce70e9 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
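[Editor's note] OrcRelation can also be reached through the generic data source API rather than the orcFile implicit, which is what the partition discovery tests below do. A hedged sketch with a made-up path:

  // Resolves DefaultSource, which builds an OrcRelation over the given path.
  val df = sqlContext.load(
    "org.apache.spark.sql.hive.orc.DefaultSource",
    Map("path" -> "/tmp/orc-data"))
  df.printSchema()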
+ */ + +package org.apache.spark.sql.hive.orc + +import java.util._ +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.io.orc._ +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat + +import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.sql.sources.Filter +import org.apache.spark.{Logging, SerializableWritable} + +/* Implicit conversions */ +import scala.collection.JavaConversions._ + +private[orc] case class OrcTableScan(attributes: Seq[Attribute], + @transient relation: OrcRelation, + filters: Array[Filter], + inputPaths: Array[String]) extends Logging { + @transient private val sqlContext = relation.sqlContext + + private def addColumnIds( + output: Seq[Attribute], + relation: OrcRelation, + conf: Configuration): Unit = { + val ids = output.map(a => relation.dataSchema.fieldIndex(a.name): Integer) + val (sortedIds, sortedNames) = ids.zip(attributes.map(_.name)).sorted.unzip + HiveShim.appendReadColumns(conf, sortedIds, sortedNames) + } + + private def buildFilter(job: Job, filters: Array[Filter]): Unit = { + if (ORC_FILTER_PUSHDOWN_ENABLED) { + val conf: Configuration = job.getConfiguration + OrcFilters.createFilter(filters).foreach { f => + conf.set(SARG_PUSHDOWN, toKryo(f)) + conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) + } + } + } + + // Transform all given raw `Writable`s into `Row`s. + private def fillObject( + path: String, + conf: Configuration, + iterator: Iterator[org.apache.hadoop.io.Writable], + nonPartitionKeyAttrs: Seq[(Attribute, Int)], + mutableRow: MutableRow): Iterator[Row] = { + val deserializer = new OrcSerde + val soi = OrcFileOperator.getObjectInspector(path, Some(conf)) + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { + case (attr, ordinal) => + soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal + }.unzip + val unwrappers = HadoopTypeConverter.unwrappers(fieldRefs) + // Map each tuple to a row object + iterator.map { value => + val raw = deserializer.deserialize(value) + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 + } + mutableRow: Row + } + } + + def execute(): RDD[Row] = { + val sc = sqlContext.sparkContext + val job = new Job(sc.hadoopConfiguration) + val conf: Configuration = job.getConfiguration + + buildFilter(job, filters) + addColumnIds(attributes, relation, conf) + FileInputFormat.setInputPaths(job, inputPaths.map(new Path(_)): _*) + + val inputClass = classOf[OrcInputFormat].asInstanceOf[ + Class[_ <: org.apache.hadoop.mapred.InputFormat[NullWritable, Writable]]] + + val rdd = sc.hadoopRDD(conf.asInstanceOf[JobConf], + inputClass, classOf[NullWritable], classOf[Writable]) + .asInstanceOf[HadoopRDD[NullWritable, Writable]] + val wrappedConf = new SerializableWritable(conf) + val rowRdd: RDD[Row] = rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iter) => + val pathStr = split.getPath.toString + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + fillObject(pathStr, wrappedConf.value, iter.map(_._2), 
attributes.zipWithIndex, mutableRow) + } + rowRdd + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala new file mode 100644 index 000000000000..b219fbb44ca0 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import com.esotericsoftware.kryo.io.Output +import com.esotericsoftware.kryo.Kryo +import org.apache.commons.codec.binary.Base64 +import org.apache.spark.sql.{SaveMode, DataFrame} + +package object orc { + implicit class OrcContext(sqlContext: HiveContext) { + import sqlContext._ + @scala.annotation.varargs + def orcFile(path: String, paths: String*): DataFrame = { + val pathArray: Array[String] = { + if (paths.isEmpty) { + Array(path) + } else { + paths.toArray ++ Array(path) + } + } + val orcRelation = OrcRelation(pathArray, Map.empty)(sqlContext) + sqlContext.baseRelationToDataFrame(orcRelation) + } + } + + implicit class OrcSchemaRDD(dataFrame: DataFrame) { + def saveAsOrcFile(path: String, mode: SaveMode = SaveMode.Overwrite): Unit = { + dataFrame.save( + path, + source = classOf[DefaultSource].getCanonicalName, + mode) + } + } + + // Flags for orc copression, predicates pushdown, etc. + val orcDefaultCompressVar = "hive.exec.orc.default.compress" + var ORC_FILTER_PUSHDOWN_ENABLED = true + val SARG_PUSHDOWN = "sarg.pushdown" + + def toKryo(input: Any): String = { + val out = new Output(4 * 1024, 10 * 1024 * 1024); + new Kryo().writeObject(out, input); + out.close(); + Base64.encodeBase64String(out.toBytes()); + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala new file mode 100644 index 000000000000..31a829a81124 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
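[Editor's note] The package object is the user-facing surface of this patch: OrcContext adds orcFile to HiveContext and OrcSchemaRDD adds saveAsOrcFile to DataFrame. A hypothetical round trip (hiveContext and the paths are assumed, not part of the patch):

  import org.apache.spark.sql.hive.orc._

  val people = hiveContext.table("people")                // any DataFrame
  people.saveAsOrcFile("/tmp/people.orc")                 // OrcSchemaRDD implicit, Overwrite by default
  val restored = hiveContext.orcFile("/tmp/people.orc")   // OrcContext implicit
  restored.registerTempTable("people_orc")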
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.util.Utils +import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + + +// The data where the partitioning key exists only in the directory structure. +case class OrcParData(intField: Int, stringField: String) + +// The data that also includes the partitioning key +case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) + +class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { + val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal + + def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + def makeOrcFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().saveAsOrcFile(path.getCanonicalPath) + } + + + def makeOrcFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.saveAsOrcFile(path.getCanonicalPath) + } + + protected def withTempTable(tableName: String)(f: => Unit): Unit = { + try f finally TestHive.dropTempTable(tableName) + } + + protected def makePartitionDir( + basePath: File, + defaultPartitionName: String, + partitionCols: (String, Any)*): File = { + val partNames = partitionCols.map { case (k, v) => + val valueString = if (v == null || v == "") defaultPartitionName else v.toString + s"$k=$valueString" + } + + val partDir = partNames.foldLeft(basePath) { (parent, child) => + new File(parent, child) + } + + assert(partDir.mkdirs(), s"Couldn't create directory $partDir") + partDir + } + + test("read partitioned table - normal case") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeOrcFile( + (1 to 10).map(i => OrcParData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + TestHive.orcFile(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, 1, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, i.toString, pi, "foo")) + } + } + } + + test("read partitioned table - partition key included in orc file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeOrcFile( + (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + TestHive.orcFile(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi 
<- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, 1, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, "foo")) + } + } + } + + + test("read partitioned table - with nulls") { + withTempDir { base => + for { + // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero... + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeOrcFile( + (1 to 10).map(i => OrcParData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val orcRelation = load( + "org.apache.spark.sql.hive.orc.DefaultSource", + Map( + "path" -> base.getCanonicalPath, + ConfVars.DEFAULTPARTITIONNAME.varname -> defaultPartitionName)) + + orcRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi IS NULL"), + for { + i <- 1 to 10 + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, null, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + } yield Row(i, i.toString, pi, null)) + } + } + } + + test("read partitioned table - with nulls and partition keys are included in Orc file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeOrcFile( + (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val orcRelation = load( + "org.apache.spark.sql.hive.orc.DefaultSource", + Map( + "path" -> base.getCanonicalPath, + ConfVars.DEFAULTPARTITIONNAME.varname -> defaultPartitionName)) + + orcRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, null)) + } + } + } +} + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala new file mode 100644 index 000000000000..475af3d4c94e --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
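[Editor's note] The partitioned tests above rely on the Hive-style directory layout produced by makePartitionDir; for OrcParData the partition values never appear inside the ORC files and are recovered from the path. An illustrative layout and the schema a reader would see (base is the test's temp directory; output shape is indicative only):

  //   base/pi=1/ps=foo/part-r-00000-<ts>.orc
  //   base/pi=1/ps=bar/part-r-00000-<ts>.orc
  //   base/pi=2/ps=foo/part-r-00000-<ts>.orc
  //   base/pi=2/ps=bar/part-r-00000-<ts>.orc
  val df = TestHive.orcFile(base.getCanonicalPath)
  df.printSchema()
  // intField and stringField come from the files; pi and ps come from the directory names.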
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import org.apache.hadoop.hive.ql.io.orc.CompressionKind +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.util.Utils +import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} + +case class TestRDDEntry(key: Int, value: String) + +case class NullReflectData( + intField: java.lang.Integer, + longField: java.lang.Long, + floatField: java.lang.Float, + doubleField: java.lang.Double, + booleanField: java.lang.Boolean) + +case class OptionalReflectData( + intField: Option[Int], + longField: Option[Long], + floatField: Option[Float], + doubleField: Option[Double], + booleanField: Option[Boolean]) + +case class Nested(i: Int, s: String) + +case class Data(array: Seq[Int], nested: Nested) + +case class AllDataTypes( + stringField: String, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + shortField: Short, + byteField: Byte, + booleanField: Boolean) + +case class AllDataTypesWithNonPrimitiveType( + stringField: String, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + shortField: Short, + byteField: Byte, + booleanField: Boolean, + array: Seq[Int], + arrayContainsNull: Seq[Option[Int]], + map: Map[Int, Long], + mapValueContainsNull: Map[Int, Option[Long]], + data: Data) + +case class BinaryData(binaryData: Array[Byte]) + +case class Contact(name: String, phone: String) + +case class Person(name: String, age: Int, contacts: Seq[Contact]) + +class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { + + def getTempFilePath(prefix: String, suffix: String = ""): File = { + val tempFile = File.createTempFile(prefix, suffix) + tempFile.delete() + tempFile + } + + test("Read/Write All Types") { + val tempDir = getTempFilePath("orcTest").getCanonicalPath + val range = (0 to 255) + val data = sparkContext.parallelize(range) + .map(x => + AllDataTypes(s"$x", x, x.toLong, x.toFloat,x.toDouble, x.toShort, x.toByte, x % 2 == 0)) + data.toDF().saveAsOrcFile(tempDir) + checkAnswer( + TestHive.orcFile(tempDir), + data.toDF().collect().toSeq) + Utils.deleteRecursively(new File(tempDir)) + } + + test("read/write binary data") { + val tempDir = getTempFilePath("orcTest").getCanonicalPath + sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil) + .toDF().saveAsOrcFile(tempDir) + TestHive.orcFile(tempDir) + .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) + .collect().toSeq == Seq("test") + Utils.deleteRecursively(new File(tempDir)) + } + + test("Read/Write All Types with non-primitive type") { + val tempDir = getTempFilePath("orcTest").getCanonicalPath + val range = (0 to 255) + val data = sparkContext.parallelize(range) + .map(x => AllDataTypesWithNonPrimitiveType( + s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, + (0 until x), + (0 until x).map(Option(_).filter(_ % 3 == 0)), + (0 
until x).map(i => i -> i.toLong).toMap, + (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), + Data((0 until x), Nested(x, s"$x")))) + data.toDF().saveAsOrcFile(tempDir) + + checkAnswer( + TestHive.orcFile(tempDir), + data.toDF().collect().toSeq) + Utils.deleteRecursively(new File(tempDir)) + } + + test("Creating case class RDD table") { + sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + .toDF().registerTempTable("tmp") + val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0)) + var counter = 1 + rdd.foreach { + // '===' does not like string comparison? + row: Row => { + assert(row.getString(1).equals(s"val_$counter"), + s"row $counter value ${row.getString(1)} does not match val_$counter") + counter = counter + 1 + } + } + } + + test("Simple selection form orc table") { + val tempDir = getTempFilePath("orcTest").getCanonicalPath + val data = sparkContext.parallelize((1 to 10)) + .map(i => Person(s"name_$i", i, (0 until 2).map{ m=> + Contact(s"contact_$m", s"phone_$m") })) + data.toDF().saveAsOrcFile(tempDir) + val f = TestHive.orcFile(tempDir) + f.registerTempTable("tmp") + + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // expr = leaf-0 + var rdd = sql("SELECT name FROM tmp where age <= 5") + assert(rdd.count() == 5) + + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // expr = (not leaf-0) + rdd = sql("SELECT name, contacts FROM tmp where age > 5") + assert(rdd.count() == 5) + var contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) + assert(contacts.count() == 10) + + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // leaf-1 = (LESS_THAN age 8) + // expr = (and (not leaf-0) leaf-1) + rdd = sql("SELECT name, contacts FROM tmp where age > 5 and age < 8") + assert(rdd.count() == 2) + contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) + assert(contacts.count() == 4) + + // ppd: + // leaf-0 = (LESS_THAN age 2) + // leaf-1 = (LESS_THAN_EQUALS age 8) + // expr = (or leaf-0 (not leaf-1)) + rdd = sql("SELECT name, contacts FROM tmp where age < 2 or age > 8") + assert(rdd.count() == 3) + contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) + assert(contacts.count() == 6) + + + Utils.deleteRecursively(new File(tempDir)) + } + + test("save and load case class RDD with Nones as orc") { + val data = OptionalReflectData(None, None, None, None, None) + val rdd = sparkContext.parallelize(data :: Nil) + val tempDir = getTempFilePath("orcTest").getCanonicalPath + rdd.toDF().saveAsOrcFile(tempDir) + val readFile = TestHive.orcFile(tempDir) + val rdd_saved = readFile.collect() + assert(rdd_saved(0).toSeq === Seq.fill(5)(null)) + Utils.deleteRecursively(new File(tempDir)) + } + + // We only support zlib in hive0.12.0 now + test("Default Compression options for writing to an Orcfile") { + // TODO: support other compress codec + var tempDir = getTempFilePath("orcTest").getCanonicalPath + val rdd = sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + rdd.toDF().saveAsOrcFile(tempDir) + var actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression + assert(actualCodec == CompressionKind.ZLIB) + Utils.deleteRecursively(new File(tempDir)) + } + + // Following codec is supported in hive-0.13.1, ignore it now + ignore("Other Compression options for writing to an Orcfile - 0.13.1 and above") { + TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "SNAPPY") + var tempDir = getTempFilePath("orcTest").getCanonicalPath + val rdd = sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + 
rdd.toDF().saveAsOrcFile(tempDir) + var actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression + assert(actualCodec == CompressionKind.SNAPPY) + Utils.deleteRecursively(new File(tempDir)) + + TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "NONE") + tempDir = getTempFilePath("orcTest").getCanonicalPath + rdd.toDF().saveAsOrcFile(tempDir) + actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression + assert(actualCodec == CompressionKind.NONE) + Utils.deleteRecursively(new File(tempDir)) + + TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "LZO") + tempDir = getTempFilePath("orcTest").getCanonicalPath + rdd.toDF().saveAsOrcFile(tempDir) + actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression + assert(actualCodec == CompressionKind.LZO) + Utils.deleteRecursively(new File(tempDir)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala new file mode 100644 index 000000000000..1d8c421b9067 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
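[Editor's note] The compression tests work by setting hive.exec.orc.default.compress (exposed as orcDefaultCompressVar) on the Hadoop configuration before writing and then reading the codec back from the file footer; only ZLIB is honoured on Hive 0.12.0, which is why the SNAPPY/NONE/LZO cases are ignored. A hedged sketch of that check, with a made-up path:

  TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "SNAPPY")
  someDF.saveAsOrcFile("/tmp/orc-snappy")   // someDF stands for any DataFrame
  println(OrcFileOperator.getFileReader("/tmp/orc-snappy").getCompression)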
+ */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.sources.{FSBasedRelationTest} +import org.apache.spark.sql.types._ + + +class FSBasedOrcRelationSuite extends FSBasedRelationTest { + override val dataSourceName: String = classOf[DefaultSource].getCanonicalName + + import sqlContext._ + import sqlContext.implicits._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) + .toDF("a", "b", "p1") + .saveAsOrcFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + load( + source = dataSourceName, + options = Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchemaWithPartition.json))) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala new file mode 100644 index 000000000000..f86750bcfb6d --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
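[Editor's note] The "dataSchema" option used above carries the non-partition schema to the data source as JSON. A hypothetical construction of that value (the a/b/p1 column names follow the test; types and nullability are assumed for illustration):

  import org.apache.spark.sql.types._

  val dataSchemaWithPartition = StructType(Seq(
    StructField("a", IntegerType, nullable = false),
    StructField("b", StringType, nullable = true),
    StructField("p1", IntegerType, nullable = true)))
  println(dataSchemaWithPartition.json)  // the string passed as options("dataSchema")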
+ */ + +package org.apache.spark.sql.hive.orc + +import java.io.File +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.hive.test.TestHive._ + +case class OrcData(intField: Int, stringField: String) + +abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { + var orcTableDir: File = null + var orcTableAsDir: File = null + + override def beforeAll(): Unit = { + super.beforeAll() + + orcTableAsDir = File.createTempFile("orctests", "sparksql") + orcTableAsDir.delete() + orcTableAsDir.mkdir() + + // Hack: to prepare orc data files using hive external tables + orcTableDir = File.createTempFile("orctests", "sparksql") + orcTableDir.delete() + orcTableDir.mkdir() + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + (sparkContext + .makeRDD(1 to 10) + .map(i => OrcData(i, s"part-$i"))) + .toDF.registerTempTable(s"orc_temp_table") + + sql(s""" + create external table normal_orc + ( + intField INT, + stringField STRING + ) + STORED AS orc + location '${orcTableDir.getCanonicalPath}' + """) + + sql( + s"""insert into table normal_orc + select intField, stringField from orc_temp_table""") + + } + + override def afterAll(): Unit = { + orcTableDir.delete() + orcTableAsDir.delete() + } + + test("create temporary orc table") { + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + Row(1, "part-1") :: + Row(2, "part-2") :: + Row(3, "part-3") :: + Row(4, "part-4") :: + Row(5, "part-5") :: + Row(6, "part-6") :: + Row(7, "part-7") :: + Row(8, "part-8") :: + Row(9, "part-9") :: + Row(10, "part-10") :: Nil + ) + + checkAnswer( + sql("SELECT * FROM normal_orc_source where intField > 5"), + Row(6, "part-6") :: + Row(7, "part-7") :: + Row(8, "part-8") :: + Row(9, "part-9") :: + Row(10, "part-10") :: Nil + ) + + checkAnswer( + sql("SELECT count(intField), stringField FROM normal_orc_source group by stringField"), + Row(1, "part-1") :: + Row(1, "part-2") :: + Row(1, "part-3") :: + Row(1, "part-4") :: + Row(1, "part-5") :: + Row(1, "part-6") :: + Row(1, "part-7") :: + Row(1, "part-8") :: + Row(1, "part-9") :: + Row(1, "part-10") :: Nil + ) + + } + + test("create temporary orc table as") { + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10)) + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + Row(1, "part-1") :: + Row(2, "part-2") :: + Row(3, "part-3") :: + Row(4, "part-4") :: + Row(5, "part-5") :: + Row(6, "part-6") :: + Row(7, "part-7") :: + Row(8, "part-8") :: + Row(9, "part-9") :: + Row(10, "part-10") :: Nil + ) + + checkAnswer( + sql("SELECT * FROM normal_orc_source where intField > 5"), + Row(6, "part-6") :: + Row(7, "part-7") :: + Row(8, "part-8") :: + Row(9, "part-9") :: + Row(10, "part-10") :: Nil + ) + + checkAnswer( + sql("SELECT count(intField), stringField FROM normal_orc_source group by stringField"), + Row(1, "part-1") :: + Row(1, "part-2") :: + Row(1, "part-3") :: + Row(1, "part-4") :: + Row(1, "part-5") :: + Row(1, "part-6") :: + Row(1, "part-7") :: + Row(1, "part-8") :: + Row(1, "part-9") :: + Row(1, "part-10") :: Nil + ) + + } + + test("appending insert") { + sql("insert into table normal_orc_source select * from orc_temp_table where intField > 5") + checkAnswer( + sql("select * from normal_orc_source"), + Row(1, "part-1") :: + Row(2, "part-2") :: + Row(3, "part-3") :: + Row(4, "part-4") :: + Row(5, "part-5") :: + Row(6, "part-6") :: + Row(6, "part-6") :: + Row(7, "part-7") :: + Row(7, "part-7") :: + Row(8, 
"part-8") :: + Row(8, "part-8") :: + Row(9, "part-9") :: + Row(9, "part-9") :: + Row(10, "part-10") :: + Row(10, "part-10") :: Nil + ) + } + + test("overwrite insert") { + sql("insert overwrite table normal_orc_as_source select * " + + "from orc_temp_table where intField > 5") + checkAnswer( + sql("select * from normal_orc_as_source"), + Row(6, "part-6") :: + Row(7, "part-7") :: + Row(8, "part-8") :: + Row(9, "part-9") :: + Row(10, "part-10") :: Nil + ) + } +} + +class OrcSourceSuite extends OrcSuite { + override def beforeAll(): Unit = { + super.beforeAll() + + sql( s""" + create temporary table normal_orc_source + USING org.apache.spark.sql.hive.orc + OPTIONS ( + path '${new File(orcTableDir.getAbsolutePath).getCanonicalPath}' + ) + """) + + sql( s""" + create temporary table normal_orc_as_source + USING org.apache.spark.sql.hive.orc + OPTIONS ( + path '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' + ) + as select * from orc_temp_table + """) + } +}