diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index f16d824201e7..0c3d9a4895fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -396,6 +396,7 @@ case class DataSource(
           hs.partitionSchema.map(_.name),
           "in the partition schema",
           equality)
+        DataSourceUtils.verifyReadSchema(hs.fileFormat, hs.dataSchema)
       case _ =>
         SchemaUtils.checkColumnNameDuplication(
           relation.schema.map(_.name),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index c5347218c4b4..82e99190ecf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -17,10 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources
 
-import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
-import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
-import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.types._
 
 
@@ -42,65 +39,14 @@ object DataSourceUtils {
 
   /**
    * Verify if the schema is supported in datasource. This verification should be done
-   * in a driver side, e.g., `prepareWrite`, `buildReader`, and `buildReaderWithPartitionValues`
-   * in `FileFormat`.
-   *
-   * Unsupported data types of csv, json, orc, and parquet are as follows;
-   *  csv -> R/W: Interval, Null, Array, Map, Struct
-   *  json -> W: Interval
-   *  orc -> W: Interval, Null
-   *  parquet -> R/W: Interval, Null
+   * on the driver side.
    */
   private def verifySchema(format: FileFormat, schema: StructType, isReadPath: Boolean): Unit = {
-    def throwUnsupportedException(dataType: DataType): Unit = {
-      throw new UnsupportedOperationException(
-        s"$format data source does not support ${dataType.simpleString} data type.")
+    schema.foreach { field =>
+      if (!format.supportDataType(field.dataType, isReadPath)) {
+        throw new AnalysisException(
+          s"$format data source does not support ${field.dataType.simpleString} data type.")
+      }
     }
-
-    def verifyType(dataType: DataType): Unit = dataType match {
-      case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
-           StringType | BinaryType | DateType | TimestampType | _: DecimalType =>
-
-      // All the unsupported types for CSV
-      case _: NullType | _: CalendarIntervalType | _: StructType | _: ArrayType | _: MapType
-          if format.isInstanceOf[CSVFileFormat] =>
-        throwUnsupportedException(dataType)
-
-      case st: StructType => st.foreach { f => verifyType(f.dataType) }
-
-      case ArrayType(elementType, _) => verifyType(elementType)
-
-      case MapType(keyType, valueType, _) =>
-        verifyType(keyType)
-        verifyType(valueType)
-
-      case udt: UserDefinedType[_] => verifyType(udt.sqlType)
-
-      // Interval type not supported in all the write path
-      case _: CalendarIntervalType if !isReadPath =>
-        throwUnsupportedException(dataType)
-
-      // JSON and ORC don't support an Interval type, but we pass it in read pass
-      // for back-compatibility.
-      case _: CalendarIntervalType if format.isInstanceOf[JsonFileFormat] ||
-        format.isInstanceOf[OrcFileFormat] =>
-
-      // Interval type not supported in the other read path
-      case _: CalendarIntervalType =>
-        throwUnsupportedException(dataType)
-
-      // For JSON & ORC backward-compatibility
-      case _: NullType if format.isInstanceOf[JsonFileFormat] ||
-        (isReadPath && format.isInstanceOf[OrcFileFormat]) =>
-
-      // Null type not supported in the other path
-      case _: NullType =>
-        throwUnsupportedException(dataType)
-
-      // We keep this default case for safeguards
-      case _ => throwUnsupportedException(dataType)
-    }
-
-    schema.foreach(field => verifyType(field.dataType))
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index 023e12788829..2c162e23644e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 
 
 /**
@@ -57,7 +57,7 @@ trait FileFormat {
       dataSchema: StructType): OutputWriterFactory
 
   /**
-   * Returns whether this format support returning columnar batch or not.
+   * Returns whether this format supports returning columnar batch or not.
    *
    * TODO: we should just have different traits for the different formats.
    */
@@ -152,6 +152,11 @@ trait FileFormat {
     }
   }
 
+  /**
+   * Returns whether this format supports the given [[DataType]] in the read/write path.
+   * By default, all data types are supported.
+   */
+  def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = true
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 52da8356ab83..7c6ab4bc922f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -96,9 +96,11 @@ object FileFormatWriter extends Logging {
 
     val caseInsensitiveOptions = CaseInsensitiveMap(options)
 
+    val dataSchema = dataColumns.toStructType
+    DataSourceUtils.verifyWriteSchema(fileFormat, dataSchema)
     // Note: prepareWrite has side effect. It sets "job".
     val outputWriterFactory =
-      fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType)
+      fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema)
 
     val description = new WriteJobDescription(
       uuid = UUID.randomUUID().toString,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index fa366ccce6b6..aeb40e5a4131 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -66,7 +66,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    DataSourceUtils.verifyWriteSchema(this, dataSchema)
     val conf = job.getConfiguration
     val csvOptions = new CSVOptions(
       options,
@@ -98,7 +97,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
-    DataSourceUtils.verifyReadSchema(this, dataSchema)
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
 
@@ -153,6 +151,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
   override def hashCode(): Int = getClass.hashCode()
 
   override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat]
+
+  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
+    case _: AtomicType => true
+
+    case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath)
+
+    case _ => false
+  }
+
 }
 
 private[csv] class CsvOutputWriter(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 383bff1375a9..a9241afba537 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSON
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
 
 class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
@@ -65,8 +65,6 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    DataSourceUtils.verifyWriteSchema(this, dataSchema)
-
     val conf = job.getConfiguration
     val parsedOptions = new JSONOptions(
       options,
@@ -98,8 +96,6 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
-    DataSourceUtils.verifyReadSchema(this, dataSchema)
-
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
 
@@ -148,6 +144,23 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
   override def hashCode(): Int = getClass.hashCode()
 
   override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat]
+
+  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
+    case _: AtomicType => true
+
+    case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) }
+
+    case ArrayType(elementType, _) => supportDataType(elementType, isReadPath)
+
+    case MapType(keyType, valueType, _) =>
+      supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath)
+
+    case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath)
+
+    case _: NullType => true
+
+    case _ => false
+  }
 }
 
 private[json] class JsonOutputWriter(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index df488a748e3e..3a8c0add8c2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -89,8 +89,6 @@ class OrcFileFormat
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    DataSourceUtils.verifyWriteSchema(this, dataSchema)
-
     val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)
 
     val conf = job.getConfiguration
@@ -143,8 +141,6 @@ class OrcFileFormat
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
-    DataSourceUtils.verifyReadSchema(this, dataSchema)
-
     if (sparkSession.sessionState.conf.orcFilterPushDown) {
       OrcFilters.createFilter(dataSchema, filters).foreach { f =>
         OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames)
@@ -228,4 +224,21 @@ class OrcFileFormat
       }
     }
   }
+
+  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
+    case _: AtomicType => true
+
+    case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) }
+
+    case ArrayType(elementType, _) => supportDataType(elementType, isReadPath)
+
+    case MapType(keyType, valueType, _) =>
+      supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath)
+
+    case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath)
+
+    case _: NullType => isReadPath
+
+    case _ => false
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 52a18abb5524..b86b97ec7b10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -78,8 +78,6 @@ class ParquetFileFormat
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    DataSourceUtils.verifyWriteSchema(this, dataSchema)
-
     val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf)
 
     val conf = ContextUtil.getConfiguration(job)
@@ -303,8 +301,6 @@ class ParquetFileFormat
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
-    DataSourceUtils.verifyReadSchema(this, dataSchema)
-
     hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
     hadoopConf.set(
       ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
@@ -454,6 +450,21 @@ class ParquetFileFormat
       }
     }
   }
+
+  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
+    case _: AtomicType => true
+
+    case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) }
+
+    case ArrayType(elementType, _) => supportDataType(elementType, isReadPath)
+
+    case MapType(keyType, valueType, _) =>
+      supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath)
+
+    case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath)
+
+    case _ => false
+  }
 }
 
 object ParquetFileFormat extends Logging {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index e93908da4353..8661a5395ac4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.{DataType, StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
 /**
@@ -47,11 +47,6 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
       throw new AnalysisException(
         s"Text data source supports only a single column, and you have ${schema.size} columns.")
     }
-    val tpe = schema(0).dataType
-    if (tpe != StringType) {
-      throw new AnalysisException(
-        s"Text data source supports only a string column, but you have ${tpe.simpleString}.")
-    }
   }
 
   override def isSplitable(
@@ -141,6 +136,9 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
       }
     }
   }
+
+  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean =
+    dataType == StringType
 }
 
 class TextOutputWriter(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 86f9647b4ac4..a7ce952b70ac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -205,63 +205,121 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
     }
   }
 
+  // Text file format only supports string type
+  test("SPARK-24691 error handling for unsupported types - text") {
+    withTempDir { dir =>
+      // write path
+      val textDir = new File(dir, "text").getCanonicalPath
+      var msg = intercept[AnalysisException] {
+        Seq(1).toDF.write.text(textDir)
+      }.getMessage
+      assert(msg.contains("Text data source does not support int data type"))
+
+      msg = intercept[AnalysisException] {
+        Seq(1.2).toDF.write.text(textDir)
+      }.getMessage
+      assert(msg.contains("Text data source does not support double data type"))
+
+      msg = intercept[AnalysisException] {
+        Seq(true).toDF.write.text(textDir)
+      }.getMessage
+      assert(msg.contains("Text data source does not support boolean data type"))
+
+      msg = intercept[AnalysisException] {
+        Seq(1).toDF("a").selectExpr("struct(a)").write.text(textDir)
+      }.getMessage
+      assert(msg.contains("Text data source does not support struct data type"))
+
+      msg = intercept[AnalysisException] {
+        Seq((Map("Tesla" -> 3))).toDF("cars").write.mode("overwrite").text(textDir)
+      }.getMessage
+      assert(msg.contains("Text data source does not support map data type"))
+
+      msg = intercept[AnalysisException] {
+        Seq((Array("Tesla", "Chevy", "Ford"))).toDF("brands")
+          .write.mode("overwrite").text(textDir)
+      }.getMessage
+      assert(msg.contains("Text data source does not support array data type"))
+
+      // read path
+      Seq("aaa").toDF.write.mode("overwrite").text(textDir)
+      msg = intercept[AnalysisException] {
+        val schema = StructType(StructField("a", IntegerType, true) :: Nil)
+        spark.read.schema(schema).text(textDir).collect()
+      }.getMessage
+      assert(msg.contains("Text data source does not support int data type"))
+
+      msg = intercept[AnalysisException] {
+        val schema = StructType(StructField("a", DoubleType, true) :: Nil)
+        spark.read.schema(schema).text(textDir).collect()
+      }.getMessage
+      assert(msg.contains("Text data source does not support double data type"))
+
+      msg = intercept[AnalysisException] {
+        val schema = StructType(StructField("a", BooleanType, true) :: Nil)
+        spark.read.schema(schema).text(textDir).collect()
+      }.getMessage
+      assert(msg.contains("Text data source does not support boolean data type"))
+    }
+  }
+
   // Unsupported data types of csv, json, orc, and parquet are as follows;
-  // csv -> R/W: Interval, Null, Array, Map, Struct
-  // json -> W: Interval
-  // orc -> W: Interval, Null
+  // csv -> R/W: Null, Array, Map, Struct
+  // json -> R/W: Interval
+  // orc -> R/W: Interval, W: Null
   // parquet -> R/W: Interval, Null
   test("SPARK-24204 error handling for unsupported Array/Map/Struct types - csv") {
     withTempDir { dir =>
       val csvDir = new File(dir, "csv").getCanonicalPath
-      var msg = intercept[UnsupportedOperationException] {
+      var msg = intercept[AnalysisException] {
         Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir)
       }.getMessage
       assert(msg.contains("CSV data source does not support struct data type"))
 
-      msg = intercept[UnsupportedOperationException] {
+      msg = intercept[AnalysisException] {
         val schema = StructType.fromDDL("a struct<b: Int>")
         spark.range(1).write.mode("overwrite").csv(csvDir)
         spark.read.schema(schema).csv(csvDir).collect()
       }.getMessage
       assert(msg.contains("CSV data source does not support struct data type"))
 
-      msg = intercept[UnsupportedOperationException] {
+      msg = intercept[AnalysisException] {
         Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.mode("overwrite").csv(csvDir)
       }.getMessage
       assert(msg.contains("CSV data source does not support map data type"))
 
-      msg = intercept[UnsupportedOperationException] {
+      msg = intercept[AnalysisException] {
         val schema = StructType.fromDDL("a map<int, int>")
         spark.range(1).write.mode("overwrite").csv(csvDir)
         spark.read.schema(schema).csv(csvDir).collect()
       }.getMessage
       assert(msg.contains("CSV data source does not support map data type"))
 
-      msg = intercept[UnsupportedOperationException] {
+      msg = intercept[AnalysisException] {
         Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands")
           .write.mode("overwrite").csv(csvDir)
       }.getMessage
       assert(msg.contains("CSV data source does not support array data type"))
 
-      msg = intercept[UnsupportedOperationException] {
+      msg = intercept[AnalysisException] {
         val schema = StructType.fromDDL("a array<int>")
         spark.range(1).write.mode("overwrite").csv(csvDir)
         spark.read.schema(schema).csv(csvDir).collect()
       }.getMessage
       assert(msg.contains("CSV data source does not support array data type"))
 
-      msg = intercept[UnsupportedOperationException] {
+      msg = intercept[AnalysisException] {
         Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors")
           .write.mode("overwrite").csv(csvDir)
       }.getMessage
-      assert(msg.contains("CSV data source does not support array data type"))
+      assert(msg.contains("CSV data source does not support mydensevector data type"))
 
-      msg = intercept[UnsupportedOperationException] {
+      msg = intercept[AnalysisException] {
         val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil)
         spark.range(1).write.mode("overwrite").csv(csvDir)
         spark.read.schema(schema).csv(csvDir).collect()
       }.getMessage
-      assert(msg.contains("CSV data source does not support array data type."))
+      assert(msg.contains("CSV data source does not support mydensevector data type."))
     }
   }
 
@@ -276,17 +334,17 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
       }.getMessage
       assert(msg.contains("Cannot save interval data type into external storage."))
 
-        msg = intercept[UnsupportedOperationException] {
+        msg = intercept[AnalysisException] {
          spark.udf.register("testType", () => new IntervalData())
          sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
        }.getMessage
        assert(msg.toLowerCase(Locale.ROOT)
-          .contains(s"$format data source does not support calendarinterval data type."))
+          .contains(s"$format data source does not support interval data type."))
      }
 
      // read path
      Seq("parquet", "csv").foreach { format =>
-        var msg = intercept[UnsupportedOperationException] {
+        var msg = intercept[AnalysisException] {
          val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
          spark.range(1).write.format(format).mode("overwrite").save(tempDir)
          spark.read.schema(schema).format(format).load(tempDir).collect()
@@ -294,26 +352,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
        assert(msg.toLowerCase(Locale.ROOT)
          .contains(s"$format data source does not support calendarinterval data type."))
 
-        msg = intercept[UnsupportedOperationException] {
+        msg = intercept[AnalysisException] {
          val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
          spark.range(1).write.format(format).mode("overwrite").save(tempDir)
          spark.read.schema(schema).format(format).load(tempDir).collect()
        }.getMessage
        assert(msg.toLowerCase(Locale.ROOT)
-          .contains(s"$format data source does not support calendarinterval data type."))
-      }
-
-      // We expect the types below should be passed for backward-compatibility
-      Seq("orc", "json").foreach { format =>
-        // Interval type
-        var schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
-        spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-        spark.read.schema(schema).format(format).load(tempDir).collect()
-
-        // UDT having interval data
-        schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
-        spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-        spark.read.schema(schema).format(format).load(tempDir).collect()
+          .contains(s"$format data source does not support interval data type."))
       }
     }
   }
@@ -324,13 +369,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
 
       Seq("orc").foreach { format =>
         // write path
-        var msg = intercept[UnsupportedOperationException] {
+        var msg = intercept[AnalysisException] {
           sql("select null").write.format(format).mode("overwrite").save(tempDir)
         }.getMessage
         assert(msg.toLowerCase(Locale.ROOT)
           .contains(s"$format data source does not support null data type."))
 
-        msg = intercept[UnsupportedOperationException] {
+        msg = intercept[AnalysisException] {
           spark.udf.register("testType", () => new NullData())
           sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
         }.getMessage
@@ -353,13 +398,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
 
       Seq("parquet", "csv").foreach { format =>
         // write path
-        var msg = intercept[UnsupportedOperationException] {
+        var msg = intercept[AnalysisException] {
           sql("select null").write.format(format).mode("overwrite").save(tempDir)
         }.getMessage
         assert(msg.toLowerCase(Locale.ROOT)
           .contains(s"$format data source does not support null data type."))
 
-        msg = intercept[UnsupportedOperationException] {
+        msg = intercept[AnalysisException] {
           spark.udf.register("testType", () => new NullData())
           sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
         }.getMessage
@@ -367,7 +412,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
           .contains(s"$format data source does not support null data type."))
 
         // read path
-        msg = intercept[UnsupportedOperationException] {
+        msg = intercept[AnalysisException] {
           val schema = StructType(StructField("a", NullType, true) :: Nil)
           spark.range(1).write.format(format).mode("overwrite").save(tempDir)
           spark.read.schema(schema).format(format).load(tempDir).collect()
@@ -375,7 +420,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
         assert(msg.toLowerCase(Locale.ROOT)
           .contains(s"$format data source does not support null data type."))
 
-        msg = intercept[UnsupportedOperationException] {
+        msg = intercept[AnalysisException] {
           val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
           spark.range(1).write.format(format).mode("overwrite").save(tempDir)
           spark.read.schema(schema).format(format).load(tempDir).collect()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 270ed7f80197..ca95aad3976e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -2513,7 +2513,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
 
   test("alter datasource table add columns - text format not supported") {
     withTable("t1") {
-      sql("CREATE TABLE t1 (c1 int) USING text")
+      sql("CREATE TABLE t1 (c1 string) USING text")
       val e = intercept[AnalysisException] {
         sql("ALTER TABLE t1 ADD COLUMNS (c2 int)")
       }.getMessage
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index 8764f0c42cf9..bceaf1a9ec06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem}
 import org.apache.hadoop.mapreduce.Job
 
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.SparkException
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index dd2144c5fcea..20090696ec3f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.orc.OrcOptions
 import org.apache.spark.sql.hive.{HiveInspectors, HiveShim}
 import org.apache.spark.sql.sources.{Filter, _}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
 
 /**
@@ -72,7 +72,6 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    DataSourceUtils.verifyWriteSchema(this, dataSchema)
 
     val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)
 
@@ -123,7 +122,6 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
-    DataSourceUtils.verifyReadSchema(this, dataSchema)
 
     if (sparkSession.sessionState.conf.orcFilterPushDown) {
       // Sets pushed predicates
@@ -178,6 +176,23 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
       }
     }
   }
+
+  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
+    case _: AtomicType => true
+
+    case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) }
+
+    case ArrayType(elementType, _) => supportDataType(elementType, isReadPath)
+
+    case MapType(keyType, valueType, _) =>
+      supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath)
+
+    case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath)
+
+    case _: NullType => isReadPath
+
+    case _ => false
+  }
 }
 
 private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala
index 69009e1b520c..fb4957ed943a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala
@@ -146,38 +146,31 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton {
       }.getMessage
       assert(msg.contains("Cannot save interval data type into external storage."))
 
-      msg = intercept[UnsupportedOperationException] {
+      msg = intercept[AnalysisException] {
         sql("select null").write.mode("overwrite").orc(orcDir)
       }.getMessage
       assert(msg.contains("ORC data source does not support null data type."))
 
-      msg = intercept[UnsupportedOperationException] {
+      msg = intercept[AnalysisException] {
         spark.udf.register("testType", () => new IntervalData())
         sql("select testType()").write.mode("overwrite").orc(orcDir)
       }.getMessage
-      assert(msg.contains("ORC data source does not support calendarinterval data type."))
+      assert(msg.contains("ORC data source does not support interval data type."))
 
       // read path
-      msg = intercept[UnsupportedOperationException] {
+      msg = intercept[AnalysisException] {
         val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
         spark.range(1).write.mode("overwrite").orc(orcDir)
         spark.read.schema(schema).orc(orcDir).collect()
       }.getMessage
       assert(msg.contains("ORC data source does not support calendarinterval data type."))
 
-      msg = intercept[UnsupportedOperationException] {
-        val schema = StructType(StructField("a", NullType, true) :: Nil)
-        spark.range(1).write.mode("overwrite").orc(orcDir)
-        spark.read.schema(schema).orc(orcDir).collect()
-      }.getMessage
-      assert(msg.contains("ORC data source does not support null data type."))
-
-      msg = intercept[UnsupportedOperationException] {
+      msg = intercept[AnalysisException] {
         val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
         spark.range(1).write.mode("overwrite").orc(orcDir)
         spark.read.schema(schema).orc(orcDir).collect()
       }.getMessage
-      assert(msg.contains("ORC data source does not support calendarinterval data type."))
+      assert(msg.contains("ORC data source does not support interval data type."))
     }
   }
 }
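
The user-visible effect of this patch is that schema validation for unsupported types now fails on the driver with AnalysisException raised from the shared DataSourceUtils verification, instead of UnsupportedOperationException thrown inside each format's prepareWrite/buildReader. Below is a minimal sketch of that behavior from a caller's point of view, assuming a local SparkSession built against a Spark version that includes this patch; the object name and output path are illustrative only, not part of the change.

import org.apache.spark.sql.{AnalysisException, SparkSession}

object UnsupportedTypeCheckSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("spark-24691-sketch")
      .getOrCreate()
    import spark.implicits._
    try {
      // TextFileFormat.supportDataType only accepts StringType, so this write is
      // rejected on the driver with AnalysisException before any task is launched.
      Seq(1, 2, 3).toDF("value").write.mode("overwrite").text("/tmp/spark-24691-sketch")
    } catch {
      case e: AnalysisException =>
        // Expected message: "Text data source does not support int data type."
        println(e.getMessage)
    } finally {
      spark.stop()
    }
  }
}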