From 085e9a3453e235a8bfc7d63bc39c6ce28298d635 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 28 Mar 2018 23:00:09 +0900 Subject: [PATCH 1/9] WIP --- .../sql/catalyst/encoders/RowEncoder.scala | 4 +-- .../expressions/codegen/CodeGenerator.scala | 8 ++--- .../spark/sql/catalyst/json/JSONOptions.scala | 12 +++++-- .../apache/spark/sql/internal/SQLConf.scala | 11 ++++++ .../apache/spark/sql/types/StringType.scala | 2 +- .../spark/sql/types/TypePlaceholder.scala | 23 ++++++++++++ .../apache/spark/sql/DataFrameReader.scala | 3 +- .../execution/datasources/DataSource.scala | 2 +- .../datasources/json/JsonFileFormat.scala | 12 ++++--- .../datasources/json/JsonInferSchema.scala | 20 +++++++++-- .../streaming/FileStreamSource.scala | 29 ++++++++++++--- .../sql/streaming/FileStreamSourceSuite.scala | 36 +++++++++++++++++++ .../spark/sql/streaming/StreamTest.scala | 24 +++++++++++++ 13 files changed, 163 insertions(+), 23 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/TypePlaceholder.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 3340789398f9c..89405b87a1a27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -115,7 +115,7 @@ object RowEncoder { inputObject :: Nil, returnNullable = false) - case StringType => + case _: StringType => StaticInvoke( classOf[UTF8String], StringType, @@ -291,7 +291,7 @@ object RowEncoder { Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), returnNullable = false) - case StringType => + case _: StringType => Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false) case ArrayType(et, nullable) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cf0a91ff00626..fabd2f8ced136 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -329,7 +329,7 @@ class CodegenContext { def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { val value = addMutableState(javaType(dataType), variableName) val code = dataType match { - case StringType => s"$value = $initCode.clone();" + case _: StringType => s"$value = $initCode.clone();" case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } @@ -1363,7 +1363,7 @@ object CodeGenerator extends Logging { dataType match { case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" - case StringType => s"$input.getUTF8String($ordinal)" + case _: StringType => s"$input.getUTF8String($ordinal)" case BinaryType => s"$input.getBinary($ordinal)" case CalendarIntervalType => s"$input.getInterval($ordinal)" case t: StructType => s"$input.getStruct($ordinal, ${t.size})" @@ -1386,7 +1386,7 @@ object CodeGenerator extends Logging { case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) // The UTF8String, InternalRow, ArrayData and MapData 
may came from UnsafeRow, we should copy // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. - case StringType | _: StructType | _: ArrayType | _: MapType => + case _: StringType | _: StructType | _: ArrayType | _: MapType => s"$row.update($ordinal, $value.copy())" case _ => s"$row.update($ordinal, $value)" } @@ -1502,7 +1502,7 @@ object CodeGenerator extends Logging { case DoubleType => JAVA_DOUBLE case _: DecimalType => "Decimal" case BinaryType => "byte[]" - case StringType => "UTF8String" + case _: StringType => "UTF8String" case CalendarIntervalType => "CalendarInterval" case _: StructType => "InternalRow" case _: ArrayType => "ArrayData" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5f130af606e19..4f869acd4f5e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -34,17 +34,20 @@ import org.apache.spark.sql.catalyst.util._ private[sql] class JSONOptions( @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, - defaultColumnNameOfCorruptRecord: String) + defaultColumnNameOfCorruptRecord: String, + defaultIgnoreNullFieldsInStreamingSchemaInference: Boolean) extends Logging with Serializable { def this( parameters: Map[String, String], defaultTimeZoneId: String, - defaultColumnNameOfCorruptRecord: String = "") = { + defaultColumnNameOfCorruptRecord: String = "", + defaultIgnoreNullFieldsInStreamingSchemaInference: Boolean = false) = { this( CaseInsensitiveMap(parameters), defaultTimeZoneId, - defaultColumnNameOfCorruptRecord) + defaultColumnNameOfCorruptRecord, + defaultIgnoreNullFieldsInStreamingSchemaInference) } val samplingRatio = @@ -72,6 +75,9 @@ private[sql] class JSONOptions( parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + val ignoreNullFieldsInStreamingSchemaInference = + parameters.get("ignoreNullFieldsInStreamingSchemaInference").map(_.toBoolean) + .getOrElse(defaultIgnoreNullFieldsInStreamingSchemaInference) val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 895e150756567..f66cbd1a89275 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -905,6 +905,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val IGNORE_NULL_FIELDS_STREAMING_SCHEMA_INFERENCE = + buildConf("spark.sql.streaming.schemaInference.ignoreNullFields") + .internal() + .doc("Whether file-based streaming sources will ignore column of all null values or " + + "empty array during JSON schema inference") + .booleanConf + .createWithDefault(false) + val STREAMING_POLLING_DELAY = buildConf("spark.sql.streaming.pollingDelay") .internal() @@ -1316,6 +1324,9 @@ class SQLConf extends Serializable with Logging { def streamingSchemaInference: Boolean = getConf(STREAMING_SCHEMA_INFERENCE) + def ignoreNullFieldsInStreamingSchemaInference: Boolean = + 
getConf(IGNORE_NULL_FIELDS_STREAMING_SCHEMA_INFERENCE) + def streamingPollingDelay: Long = getConf(STREAMING_POLLING_DELAY) def streamingNoDataProgressEventInterval: Long = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index 59b124cda7d14..d45d41c654212 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -29,7 +29,7 @@ import org.apache.spark.unsafe.types.UTF8String * @since 1.3.0 */ @InterfaceStability.Stable -class StringType private() extends AtomicType { +class StringType private[types]() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TypePlaceholder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TypePlaceholder.scala new file mode 100644 index 0000000000000..9b353d2f004fd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TypePlaceholder.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +/** + * An internal type that is a not yet available and will be replaced by an actual type later. 
+ */ +case object TypePlaceholder extends StringType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 53f44888ebaff..3fbd6f2df5ff3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -429,7 +429,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val parsedOptions = new JSONOptions( extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord) + sparkSession.sessionState.conf.columnNameOfCorruptRecord, + sparkSession.sessionState.conf.ignoreNullFieldsInStreamingSchemaInference) val schema = userSpecifiedSchema.getOrElse { TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f16d824201e77..fed857d26cf0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -266,7 +266,7 @@ case class DataSource( sparkSession = sparkSession, path = path, fileFormatClassName = className, - schema = sourceInfo.schema, + initialSchema = sourceInfo.schema, partitionColumns = sourceInfo.partitionColumns, metadataPath = metadataPath, options = caseInsensitiveOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 3b04510d29695..7e3eaae3d4295 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -43,7 +43,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val parsedOptions = new JSONOptions( options, sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord) + sparkSession.sessionState.conf.columnNameOfCorruptRecord, + sparkSession.sessionState.conf.ignoreNullFieldsInStreamingSchemaInference) val jsonDataSource = JsonDataSource(parsedOptions) jsonDataSource.isSplitable && super.isSplitable(sparkSession, options, path) } @@ -55,7 +56,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val parsedOptions = new JSONOptions( options, sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord) + sparkSession.sessionState.conf.columnNameOfCorruptRecord, + sparkSession.sessionState.conf.ignoreNullFieldsInStreamingSchemaInference) JsonDataSource(parsedOptions).inferSchema( sparkSession, files, parsedOptions) } @@ -69,7 +71,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val parsedOptions = new JSONOptions( options, sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord) + sparkSession.sessionState.conf.columnNameOfCorruptRecord, + sparkSession.sessionState.conf.ignoreNullFieldsInStreamingSchemaInference) parsedOptions.compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } @@ -102,7 +105,8 
@@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val parsedOptions = new JSONOptions( options, sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord) + sparkSession.sessionState.conf.columnNameOfCorruptRecord, + sparkSession.sessionState.conf.ignoreNullFieldsInStreamingSchemaInference) val actualSchema = StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index a270a6451d5dd..55f25568c6b27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -101,6 +101,10 @@ private[sql] object JsonInferSchema { private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { + // TODO: We need to check if this is a streaming mode + case null | VALUE_NULL if configOptions.ignoreNullFieldsInStreamingSchemaInference => + TypePlaceholder + case null | VALUE_NULL => NullType case FIELD_NAME => @@ -131,10 +135,18 @@ private[sql] object JsonInferSchema { StructType(fields) case START_ARRAY => - // If this JSON array is empty, we use NullType as a placeholder. + // If this JSON array is empty and `ignoreNullFieldsInStreamingSchemaInference` is true, + // we use `TypePlaceholder` (This type will be resolved by using a first-encountered + // value in incoming batches. + var elementType: DataType = if (configOptions.ignoreNullFieldsInStreamingSchemaInference) { + // TODO: We need to check if this is a streaming mode + TypePlaceholder + } else { + // Otherwise, we use `NullType` as a placeholder + NullType + } // If this array is not empty in other JSON objects, we can resolve // the type as we pass through all JSON objects. - var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { elementType = compatibleType( elementType, inferField(parser, configOptions)) @@ -336,6 +348,10 @@ private[sql] object JsonInferSchema { case (t1: DecimalType, t2: IntegralType) => compatibleType(t1, DecimalType.forType(t2)) + // If `TypePlaceholder` found, return the other type + case (t1, TypePlaceholder) => t1 + case (TypePlaceholder, t2) => t2 + // strings and every string is a Json object. case (_, _) => StringType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 8c016abc5b643..59ce7b0880fbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -27,7 +27,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.datasources.{DataSource, InMemoryFileIndex, LogicalRelation} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructType, TypePlaceholder} /** * A very simple source that reads files from the given directory as they appear. 
@@ -36,13 +36,20 @@ class FileStreamSource( sparkSession: SparkSession, path: String, fileFormatClassName: String, - override val schema: StructType, + initialSchema: StructType, partitionColumns: Seq[String], metadataPath: String, options: Map[String, String]) extends Source with Logging { import FileStreamSource._ + private val ignoreNullFieldsInSchemaInference = + sparkSession.sqlContext.conf.ignoreNullFieldsInStreamingSchemaInference + + private var _schema: Option[StructType] = None + + override def schema: StructType = _schema.getOrElse(initialSchema) + private val sourceOptions = new FileStreamOptions(options) private val hadoopConf = sparkSession.sessionState.newHadoopConf() @@ -163,16 +170,28 @@ class FileStreamSource( val files = metadataLog.get(Some(startOffset + 1), Some(endOffset)).flatMap(_._2) logInfo(s"Processing ${files.length} files from ${startOffset + 1}:$endOffset") logTrace(s"Files are:\n\t" + files.mkString("\n\t")) + + // If `ignoreNullFieldsInSchemaInference` is true and unresolved types exist in `schema`, + // we trigger schema inference in the current batch. + val doInferSchema = ignoreNullFieldsInSchemaInference && + schema.existsRecursively(_.acceptsType(TypePlaceholder)) + val newDataSource = DataSource( sparkSession, paths = files.map(f => new Path(new URI(f.path)).toString), - userSpecifiedSchema = Some(schema), + userSpecifiedSchema = if (!doInferSchema) Some(schema) else None, partitionColumns = partitionColumns, className = fileFormatClassName, options = optionsWithPartitionBasePath) - Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( - checkFilesExist = false), isStreaming = true)) + + val rel = newDataSource.resolveRelation(checkFilesExist = false) + // If schema inference triggered in the batch, replace the current `schema` with new one + if (doInferSchema) { + _schema = Some(rel.schema) + } + + Dataset.ofRows(sparkSession, LogicalRelation(rel, isStreaming = true)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index d4bd9c7987f2d..f92f94d358341 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -624,6 +624,42 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("SPARK-23772 Ignore column of all null values or empty array during JSON schema inference") { + withTempDirs { case (src, tmp) => + withSQLConf( + SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true", + SQLConf.IGNORE_NULL_FIELDS_STREAMING_SCHEMA_INFERENCE.key -> "true") { + + // Add a file so that we can infer its schema + stringToFile(new File(src, "existing"), "{'c0': 1, 'c1': null, 'c2': []}") + + val fileStream = createFileStream("json", src.getCanonicalPath) + + // FileStreamSource should infer the column "k" + assert(fileStream.schema === StructType( + StructField("c0", LongType) :: + StructField("c1", TypePlaceholder) :: + StructField("c2", ArrayType(TypePlaceholder)) :: Nil)) + + testStream(fileStream)( + + // Should not pick up column v in the file added before start + AddTextFileData("{'c0': 3, 'c1': 3.8, 'c2': []}", src, tmp), + CheckSchema(StructType( + StructField("c0", LongType) :: + StructField("c1", DoubleType) :: + StructField("c2", ArrayType(TypePlaceholder)) :: Nil)), + + // Should read data in column k, and ignore v + AddTextFileData("{'c0': 2, 'c1': 1.1, 'c2': 
[1, 2, 3]}", src, tmp), + CheckSchema(StructType( + StructField("c0", LongType) :: + StructField("c1", DoubleType) :: + StructField("c2", ArrayType(LongType)) :: Nil))) + } + } + } + // =============== ORC file stream tests ================ test("read from orc files") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 9d139a927bea5..00958e650778b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -220,6 +221,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } + case class CheckSchema(expectedSchema: StructType) extends StreamAction with StreamMustBeRunning + /** Stops the stream. It must currently be running. */ case object StopStream extends StreamAction with StreamMustBeRunning @@ -746,6 +749,27 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { error => failTest(error) } + + case CheckSchema(expectedSchema) => + // waitAllDataProcessed(currentStream) + val resultSchema = currentStream.lastExecution.analyzed.schema + if (expectedSchema != resultSchema) { + failTest( + s""" + |== Results == + |${ + sideBySide( + s""" + |== Correct Schema == + |${expectedSchema.simpleString} + """.stripMargin, + s""" + |== Spark Result Schema == + |${resultSchema.simpleString} + """.stripMargin).mkString("\n") + } + """.stripMargin) + } } pos += 1 } From 8c757815b4759f0dd7c685daa91acf9f0aaff615 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 7 May 2018 23:58:23 +0900 Subject: [PATCH 2/9] Revert "WIP" This reverts commit 085e9a3453e235a8bfc7d63bc39c6ce28298d635. 
--- .../sql/catalyst/encoders/RowEncoder.scala | 4 +-- .../expressions/codegen/CodeGenerator.scala | 8 ++--- .../spark/sql/catalyst/json/JSONOptions.scala | 12 ++----- .../apache/spark/sql/internal/SQLConf.scala | 11 ------ .../apache/spark/sql/types/StringType.scala | 2 +- .../spark/sql/types/TypePlaceholder.scala | 23 ------------ .../apache/spark/sql/DataFrameReader.scala | 3 +- .../execution/datasources/DataSource.scala | 2 +- .../datasources/json/JsonFileFormat.scala | 12 +++---- .../datasources/json/JsonInferSchema.scala | 20 ++--------- .../streaming/FileStreamSource.scala | 29 +++------------ .../sql/streaming/FileStreamSourceSuite.scala | 36 ------------------- .../spark/sql/streaming/StreamTest.scala | 24 ------------- 13 files changed, 23 insertions(+), 163 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/TypePlaceholder.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 89405b87a1a27..3340789398f9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -115,7 +115,7 @@ object RowEncoder { inputObject :: Nil, returnNullable = false) - case _: StringType => + case StringType => StaticInvoke( classOf[UTF8String], StringType, @@ -291,7 +291,7 @@ object RowEncoder { Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), returnNullable = false) - case _: StringType => + case StringType => Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false) case ArrayType(et, nullable) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index fabd2f8ced136..cf0a91ff00626 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -329,7 +329,7 @@ class CodegenContext { def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { val value = addMutableState(javaType(dataType), variableName) val code = dataType match { - case _: StringType => s"$value = $initCode.clone();" + case StringType => s"$value = $initCode.clone();" case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } @@ -1363,7 +1363,7 @@ object CodeGenerator extends Logging { dataType match { case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" - case _: StringType => s"$input.getUTF8String($ordinal)" + case StringType => s"$input.getUTF8String($ordinal)" case BinaryType => s"$input.getBinary($ordinal)" case CalendarIntervalType => s"$input.getInterval($ordinal)" case t: StructType => s"$input.getStruct($ordinal, ${t.size})" @@ -1386,7 +1386,7 @@ object CodeGenerator extends Logging { case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. 
- case _: StringType | _: StructType | _: ArrayType | _: MapType => + case StringType | _: StructType | _: ArrayType | _: MapType => s"$row.update($ordinal, $value.copy())" case _ => s"$row.update($ordinal, $value)" } @@ -1502,7 +1502,7 @@ object CodeGenerator extends Logging { case DoubleType => JAVA_DOUBLE case _: DecimalType => "Decimal" case BinaryType => "byte[]" - case _: StringType => "UTF8String" + case StringType => "UTF8String" case CalendarIntervalType => "CalendarInterval" case _: StructType => "InternalRow" case _: ArrayType => "ArrayData" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 4f869acd4f5e6..5f130af606e19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -34,20 +34,17 @@ import org.apache.spark.sql.catalyst.util._ private[sql] class JSONOptions( @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, - defaultColumnNameOfCorruptRecord: String, - defaultIgnoreNullFieldsInStreamingSchemaInference: Boolean) + defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { def this( parameters: Map[String, String], defaultTimeZoneId: String, - defaultColumnNameOfCorruptRecord: String = "", - defaultIgnoreNullFieldsInStreamingSchemaInference: Boolean = false) = { + defaultColumnNameOfCorruptRecord: String = "") = { this( CaseInsensitiveMap(parameters), defaultTimeZoneId, - defaultColumnNameOfCorruptRecord, - defaultIgnoreNullFieldsInStreamingSchemaInference) + defaultColumnNameOfCorruptRecord) } val samplingRatio = @@ -75,9 +72,6 @@ private[sql] class JSONOptions( parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) - val ignoreNullFieldsInStreamingSchemaInference = - parameters.get("ignoreNullFieldsInStreamingSchemaInference").map(_.toBoolean) - .getOrElse(defaultIgnoreNullFieldsInStreamingSchemaInference) val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f66cbd1a89275..895e150756567 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -905,14 +905,6 @@ object SQLConf { .booleanConf .createWithDefault(false) - val IGNORE_NULL_FIELDS_STREAMING_SCHEMA_INFERENCE = - buildConf("spark.sql.streaming.schemaInference.ignoreNullFields") - .internal() - .doc("Whether file-based streaming sources will ignore column of all null values or " + - "empty array during JSON schema inference") - .booleanConf - .createWithDefault(false) - val STREAMING_POLLING_DELAY = buildConf("spark.sql.streaming.pollingDelay") .internal() @@ -1324,9 +1316,6 @@ class SQLConf extends Serializable with Logging { def streamingSchemaInference: Boolean = getConf(STREAMING_SCHEMA_INFERENCE) - def ignoreNullFieldsInStreamingSchemaInference: Boolean = - getConf(IGNORE_NULL_FIELDS_STREAMING_SCHEMA_INFERENCE) - def streamingPollingDelay: Long = getConf(STREAMING_POLLING_DELAY) def streamingNoDataProgressEventInterval: Long 
= diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index d45d41c654212..59b124cda7d14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -29,7 +29,7 @@ import org.apache.spark.unsafe.types.UTF8String * @since 1.3.0 */ @InterfaceStability.Stable -class StringType private[types]() extends AtomicType { +class StringType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TypePlaceholder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TypePlaceholder.scala deleted file mode 100644 index 9b353d2f004fd..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TypePlaceholder.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.types - -/** - * An internal type that is a not yet available and will be replaced by an actual type later. 
- */ -case object TypePlaceholder extends StringType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 3fbd6f2df5ff3..53f44888ebaff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -429,8 +429,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val parsedOptions = new JSONOptions( extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord, - sparkSession.sessionState.conf.ignoreNullFieldsInStreamingSchemaInference) + sparkSession.sessionState.conf.columnNameOfCorruptRecord) val schema = userSpecifiedSchema.getOrElse { TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index fed857d26cf0a..f16d824201e77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -266,7 +266,7 @@ case class DataSource( sparkSession = sparkSession, path = path, fileFormatClassName = className, - initialSchema = sourceInfo.schema, + schema = sourceInfo.schema, partitionColumns = sourceInfo.partitionColumns, metadataPath = metadataPath, options = caseInsensitiveOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 7e3eaae3d4295..3b04510d29695 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -43,8 +43,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val parsedOptions = new JSONOptions( options, sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord, - sparkSession.sessionState.conf.ignoreNullFieldsInStreamingSchemaInference) + sparkSession.sessionState.conf.columnNameOfCorruptRecord) val jsonDataSource = JsonDataSource(parsedOptions) jsonDataSource.isSplitable && super.isSplitable(sparkSession, options, path) } @@ -56,8 +55,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val parsedOptions = new JSONOptions( options, sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord, - sparkSession.sessionState.conf.ignoreNullFieldsInStreamingSchemaInference) + sparkSession.sessionState.conf.columnNameOfCorruptRecord) JsonDataSource(parsedOptions).inferSchema( sparkSession, files, parsedOptions) } @@ -71,8 +69,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val parsedOptions = new JSONOptions( options, sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord, - sparkSession.sessionState.conf.ignoreNullFieldsInStreamingSchemaInference) + sparkSession.sessionState.conf.columnNameOfCorruptRecord) parsedOptions.compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } @@ -105,8 +102,7 
@@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val parsedOptions = new JSONOptions( options, sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord, - sparkSession.sessionState.conf.ignoreNullFieldsInStreamingSchemaInference) + sparkSession.sessionState.conf.columnNameOfCorruptRecord) val actualSchema = StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index 55f25568c6b27..a270a6451d5dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -101,10 +101,6 @@ private[sql] object JsonInferSchema { private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { - // TODO: We need to check if this is a streaming mode - case null | VALUE_NULL if configOptions.ignoreNullFieldsInStreamingSchemaInference => - TypePlaceholder - case null | VALUE_NULL => NullType case FIELD_NAME => @@ -135,18 +131,10 @@ private[sql] object JsonInferSchema { StructType(fields) case START_ARRAY => - // If this JSON array is empty and `ignoreNullFieldsInStreamingSchemaInference` is true, - // we use `TypePlaceholder` (This type will be resolved by using a first-encountered - // value in incoming batches. - var elementType: DataType = if (configOptions.ignoreNullFieldsInStreamingSchemaInference) { - // TODO: We need to check if this is a streaming mode - TypePlaceholder - } else { - // Otherwise, we use `NullType` as a placeholder - NullType - } + // If this JSON array is empty, we use NullType as a placeholder. // If this array is not empty in other JSON objects, we can resolve // the type as we pass through all JSON objects. + var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { elementType = compatibleType( elementType, inferField(parser, configOptions)) @@ -348,10 +336,6 @@ private[sql] object JsonInferSchema { case (t1: DecimalType, t2: IntegralType) => compatibleType(t1, DecimalType.forType(t2)) - // If `TypePlaceholder` found, return the other type - case (t1, TypePlaceholder) => t1 - case (TypePlaceholder, t2) => t2 - // strings and every string is a Json object. case (_, _) => StringType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 59ce7b0880fbf..8c016abc5b643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -27,7 +27,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.datasources.{DataSource, InMemoryFileIndex, LogicalRelation} -import org.apache.spark.sql.types.{StructType, TypePlaceholder} +import org.apache.spark.sql.types.StructType /** * A very simple source that reads files from the given directory as they appear. 
@@ -36,20 +36,13 @@ class FileStreamSource( sparkSession: SparkSession, path: String, fileFormatClassName: String, - initialSchema: StructType, + override val schema: StructType, partitionColumns: Seq[String], metadataPath: String, options: Map[String, String]) extends Source with Logging { import FileStreamSource._ - private val ignoreNullFieldsInSchemaInference = - sparkSession.sqlContext.conf.ignoreNullFieldsInStreamingSchemaInference - - private var _schema: Option[StructType] = None - - override def schema: StructType = _schema.getOrElse(initialSchema) - private val sourceOptions = new FileStreamOptions(options) private val hadoopConf = sparkSession.sessionState.newHadoopConf() @@ -170,28 +163,16 @@ class FileStreamSource( val files = metadataLog.get(Some(startOffset + 1), Some(endOffset)).flatMap(_._2) logInfo(s"Processing ${files.length} files from ${startOffset + 1}:$endOffset") logTrace(s"Files are:\n\t" + files.mkString("\n\t")) - - // If `ignoreNullFieldsInSchemaInference` is true and unresolved types exist in `schema`, - // we trigger schema inference in the current batch. - val doInferSchema = ignoreNullFieldsInSchemaInference && - schema.existsRecursively(_.acceptsType(TypePlaceholder)) - val newDataSource = DataSource( sparkSession, paths = files.map(f => new Path(new URI(f.path)).toString), - userSpecifiedSchema = if (!doInferSchema) Some(schema) else None, + userSpecifiedSchema = Some(schema), partitionColumns = partitionColumns, className = fileFormatClassName, options = optionsWithPartitionBasePath) - - val rel = newDataSource.resolveRelation(checkFilesExist = false) - // If schema inference triggered in the batch, replace the current `schema` with new one - if (doInferSchema) { - _schema = Some(rel.schema) - } - - Dataset.ofRows(sparkSession, LogicalRelation(rel, isStreaming = true)) + Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( + checkFilesExist = false), isStreaming = true)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index f92f94d358341..d4bd9c7987f2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -624,42 +624,6 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } - test("SPARK-23772 Ignore column of all null values or empty array during JSON schema inference") { - withTempDirs { case (src, tmp) => - withSQLConf( - SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true", - SQLConf.IGNORE_NULL_FIELDS_STREAMING_SCHEMA_INFERENCE.key -> "true") { - - // Add a file so that we can infer its schema - stringToFile(new File(src, "existing"), "{'c0': 1, 'c1': null, 'c2': []}") - - val fileStream = createFileStream("json", src.getCanonicalPath) - - // FileStreamSource should infer the column "k" - assert(fileStream.schema === StructType( - StructField("c0", LongType) :: - StructField("c1", TypePlaceholder) :: - StructField("c2", ArrayType(TypePlaceholder)) :: Nil)) - - testStream(fileStream)( - - // Should not pick up column v in the file added before start - AddTextFileData("{'c0': 3, 'c1': 3.8, 'c2': []}", src, tmp), - CheckSchema(StructType( - StructField("c0", LongType) :: - StructField("c1", DoubleType) :: - StructField("c2", ArrayType(TypePlaceholder)) :: Nil)), - - // Should read data in column k, and ignore v - AddTextFileData("{'c0': 2, 'c1': 1.1, 'c2': 
[1, 2, 3]}", src, tmp), - CheckSchema(StructType( - StructField("c0", LongType) :: - StructField("c1", DoubleType) :: - StructField("c2", ArrayType(LongType)) :: Nil))) - } - } - } - // =============== ORC file stream tests ================ test("read from orc files") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 00958e650778b..9d139a927bea5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -47,7 +47,6 @@ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -221,8 +220,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - case class CheckSchema(expectedSchema: StructType) extends StreamAction with StreamMustBeRunning - /** Stops the stream. It must currently be running. */ case object StopStream extends StreamAction with StreamMustBeRunning @@ -749,27 +746,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { error => failTest(error) } - - case CheckSchema(expectedSchema) => - // waitAllDataProcessed(currentStream) - val resultSchema = currentStream.lastExecution.analyzed.schema - if (expectedSchema != resultSchema) { - failTest( - s""" - |== Results == - |${ - sideBySide( - s""" - |== Correct Schema == - |${expectedSchema.simpleString} - """.stripMargin, - s""" - |== Spark Result Schema == - |${resultSchema.simpleString} - """.stripMargin).mkString("\n") - } - """.stripMargin) - } } pos += 1 } From 53b686dede4e5fbcb2b3e39932602ae0c9974209 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 7 May 2018 20:22:23 +0900 Subject: [PATCH 3/9] Fix --- .../spark/sql/catalyst/json/JSONOptions.scala | 12 ++++ .../sql/catalyst/json/JacksonParser.scala | 9 ++- .../execution/datasources/DataSource.scala | 2 +- .../datasources/json/JsonFileFormat.scala | 14 +++-- .../datasources/json/JsonInferSchema.scala | 10 ++-- .../streaming/FileStreamSource.scala | 28 +++++++--- .../sql/streaming/FileStreamSourceSuite.scala | 55 ++++++++++++++++++- .../spark/sql/streaming/StreamTest.scala | 42 +++++++++++--- 8 files changed, 142 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5f130af606e19..872c0c09a064f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -73,6 +73,18 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + // Whether file-based streaming sources will ignore column of all null values or empty + // array during JSON schema inference. 
+ val dropFieldIfAllNull = { + val streamingSchemaInference = + parameters.get("streamingSchemaInference").map(_.toBoolean).getOrElse(false) + val _dropFieldIfAllNull = + parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + + // We could set true at `dropFieldIfAllNull` iff `spark.sql.streaming.schemaInference` enabled + streamingSchemaInference && _dropFieldIfAllNull + } + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index a5a4a13eb608b..61753bf241a33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -329,8 +329,13 @@ class JacksonParser( while (nextUntil(parser, JsonToken.END_ARRAY)) { values += fieldConverter.apply(parser) } - - new GenericArrayData(values.toArray) + // Canonicalize arrays; an array is null if all its elements are null + // TODO: Reconsider this + if (options.dropFieldIfAllNull && values.forall(_ == null)) { + null + } else { + new GenericArrayData(values.toArray) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f16d824201e77..fed857d26cf0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -266,7 +266,7 @@ case class DataSource( sparkSession = sparkSession, path = path, fileFormatClassName = className, - schema = sourceInfo.schema, + initialSchema = sourceInfo.schema, partitionColumns = sourceInfo.partitionColumns, metadataPath = metadataPath, options = caseInsensitiveOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 3b04510d29695..d7ecb78bd7e84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -52,10 +52,11 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { + val sqlConf = sparkSession.sessionState.conf val parsedOptions = new JSONOptions( - options, - sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord) + options + ("streamingSchemaInference" -> sqlConf.streamingSchemaInference.toString), + sqlConf.sessionLocalTimeZone, + sqlConf.columnNameOfCorruptRecord) JsonDataSource(parsedOptions).inferSchema( sparkSession, files, parsedOptions) } @@ -99,10 +100,11 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val sqlConf = sparkSession.sessionState.conf val parsedOptions = new JSONOptions( - options, - sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord) + options + 
("streamingSchemaInference" -> sqlConf.streamingSchemaInference.toString), + sqlConf.sessionLocalTimeZone, + sqlConf.columnNameOfCorruptRecord) val actualSchema = StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index a270a6451d5dd..ec441a5ba1a46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -70,7 +70,7 @@ private[sql] object JsonInferSchema { }.fold(StructType(Nil))( compatibleRootType(columnNameOfCorruptRecord, parseMode)) - canonicalizeType(rootType) match { + canonicalizeType(rootType, configOptions) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep @@ -178,10 +178,10 @@ private[sql] object JsonInferSchema { /** * Convert NullType to StringType and remove StructTypes with no fields */ - private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { + private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match { case at @ ArrayType(elementType, _) => for { - canonicalType <- canonicalizeType(elementType) + canonicalType <- canonicalizeType(elementType, options) } yield { at.copy(canonicalType) } @@ -190,7 +190,7 @@ private[sql] object JsonInferSchema { val canonicalFields: Array[StructField] = for { field <- fields if field.name.length > 0 - canonicalType <- canonicalizeType(field.dataType) + canonicalType <- canonicalizeType(field.dataType, options) } yield { field.copy(dataType = canonicalType) } @@ -202,7 +202,7 @@ private[sql] object JsonInferSchema { None } - case NullType => Some(StringType) + case NullType if !options.dropFieldIfAllNull => Some(StringType) case other => Some(other) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 8c016abc5b643..3eed0745c8075 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -19,15 +19,13 @@ package org.apache.spark.sql.execution.streaming import java.net.URI -import scala.collection.JavaConverters._ - import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.datasources.{DataSource, InMemoryFileIndex, LogicalRelation} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{NullType, StructType} /** * A very simple source that reads files from the given directory as they appear. 
@@ -36,13 +34,17 @@ class FileStreamSource( sparkSession: SparkSession, path: String, fileFormatClassName: String, - override val schema: StructType, + initialSchema: StructType, partitionColumns: Seq[String], metadataPath: String, options: Map[String, String]) extends Source with Logging { import FileStreamSource._ + private var _schema: Option[StructType] = None + + override def schema: StructType = _schema.getOrElse(initialSchema) + private val sourceOptions = new FileStreamOptions(options) private val hadoopConf = sparkSession.sessionState.newHadoopConf() @@ -163,16 +165,28 @@ class FileStreamSource( val files = metadataLog.get(Some(startOffset + 1), Some(endOffset)).flatMap(_._2) logInfo(s"Processing ${files.length} files from ${startOffset + 1}:$endOffset") logTrace(s"Files are:\n\t" + files.mkString("\n\t")) + + // If the current schema has `NullType`s, we will trigger schema inference again in + // the current batch. + val doInferSchemaInCurrentBatch = schema.existsRecursively(_.acceptsType(NullType)) + val newDataSource = DataSource( sparkSession, paths = files.map(f => new Path(new URI(f.path)).toString), - userSpecifiedSchema = Some(schema), + userSpecifiedSchema = if (!doInferSchemaInCurrentBatch) Some(schema) else None, partitionColumns = partitionColumns, className = fileFormatClassName, options = optionsWithPartitionBasePath) - Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( - checkFilesExist = false), isStreaming = true)) + + val rel = newDataSource.resolveRelation(checkFilesExist = false) + // If schema inference triggered in the current batch, replaces the current `schema` + // with the inferred one. + if (doInferSchemaInCurrentBatch) { + _schema = Some(rel.schema) + } + + Dataset.ofRows(sparkSession, LogicalRelation(rel, isStreaming = true)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index d4bd9c7987f2d..ce25d9251f3c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -24,11 +24,11 @@ import scala.util.Random import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.scalatest.PrivateMethodTester -import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap} import org.apache.spark.sql.internal.SQLConf @@ -624,6 +624,56 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("SPARK-23772 Ignore column of all null values or empty array during JSON schema inference") { + withTempDirs { case (src, tmp) => + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { + + // Add a file so that we can infer its schema + stringToFile(new File(src, "existing"), + "{'c0': 1, 'c1': null, 'c2': []}\n{'c0': 2, 'c1': null, 'c2': null}") + + val fileStream = createFileStream( + "json", src.getCanonicalPath, options = Map("dropFieldIfAllNull" -> "true")) + + // FileStreamSource should infer the column "k" + assert(fileStream.schema === StructType( + StructField("c0", LongType) :: + StructField("c1", NullType) :: + StructField("c2", 
ArrayType(NullType)) :: Nil)) + + testStream(fileStream)( + + // Should not pick up column v in the file added before start + AddTextFileData("{'c0': 3, 'c1': 3.8, 'c2': [null, null]}", src, tmp), + CheckSchema(StructType( + StructField("c0", LongType) :: + StructField("c1", DoubleType) :: + StructField("c2", ArrayType(NullType)) :: Nil)), + // + // CheckAnswer(Row(1, null, Array()) :: Row(2, null, null) :: + // Row(3, 3.8, Array(null, null)) :: Nil: _*), + // + // Canonicalize arrays; an array is null if all its elements are null + CheckAnswer(Row(1, null, null) :: Row(2, null, null) :: Row(3, 3.8, null) + :: Nil: _*), + + // Should read data in column k, and ignore v + AddTextFileData("{'c0': 4, 'c1': 1.1, 'c2': [1, 2, 3]}", src, tmp), + CheckSchema(StructType( + StructField("c0", LongType) :: + StructField("c1", DoubleType) :: + StructField("c2", ArrayType(LongType)) :: Nil)), + // + // CheckAnswer(Row(1, null, Array()) :: Row(2, null, null) :: + // Row(3, 3.8, Array(null, null)) :: Row(4, 1.1, Array(1, 2, 3)) :: Nil: _*)) + // + // Canonicalize arrays; an array is null if all its elements are null + CheckAnswer(Row(1, null, null) :: Row(2, null, null) :: + Row(3, 3.8, null) :: Row(4, 1.1, Array(1, 2, 3)) :: Nil: _*)) + } + } + } + // =============== ORC file stream tests ================ test("read from orc files") { @@ -1115,7 +1165,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val df = spark.readStream.format("text").load(src.getCanonicalPath).map(_ + "-x") // Test `explain` not throwing errors - df.explain() + val explainCmd = ExplainCommand(df.queryExecution.logical, extended = false) + spark.sessionState.executePlan(explainCmd).executedPlan val q = df.writeStream.queryName("file_explain").format("memory").start() .asInstanceOf[StreamingQueryWrapper] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 9d139a927bea5..65dc3e727e298 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -44,9 +44,9 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -203,6 +203,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" } + case class CheckSchema(expectedSchema: StructType) extends StreamAction with StreamMustBeRunning + case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) extends StreamAction with StreamMustBeRunning { override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" @@ -455,12 +457,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var lastFetchedMemorySinkLastBatchId: Long = -1 - def fetchStreamAnswer( - currentStream: StreamExecution, - lastOnly: Boolean = false, - sinceLastFetchOnly: Boolean = false) = { - verify( - !(lastOnly && 
sinceLastFetchOnly), "both lastOnly and sinceLastFetchOnly cannot be true") + def waitAllDataProcessed() = { verify(currentStream != null, "stream not running") // Block until all data added has been processed for all the source @@ -473,6 +470,16 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } } + } + + def fetchStreamAnswer( + currentStream: StreamExecution, + lastOnly: Boolean = false, + sinceLastFetchOnly: Boolean = false) = { + verify( + !(lastOnly && sinceLastFetchOnly), "both lastOnly and sinceLastFetchOnly cannot be true") + + waitAllDataProcessed() val lastExecution = currentStream.lastExecution if (currentStream.isInstanceOf[MicroBatchExecution] && lastExecution != null) { @@ -746,6 +753,27 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { error => failTest(error) } + + case CheckSchema(expectedSchema) => + waitAllDataProcessed() + val resultSchema = currentStream.lastExecution.analyzed.schema + if (expectedSchema != resultSchema) { + failTest( + s""" + |== Results == + |${ + sideBySide( + s""" + |== Correct Schema == + |${expectedSchema.simpleString} + """.stripMargin, + s""" + |== Spark Result Schema == + |${resultSchema.simpleString} + """.stripMargin).mkString("\n") + } + """.stripMargin) + } } pos += 1 } From 7870f307445ddfc7ac34a077afd6bb2b6ca96e1c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 22 May 2018 16:29:51 +0900 Subject: [PATCH 4/9] Fix --- .../spark/sql/catalyst/json/JSONOptions.scala | 13 +---- .../sql/catalyst/json/JacksonParser.scala | 1 - .../datasources/json/JsonFileFormat.scala | 14 ++--- .../streaming/FileStreamSource.scala | 28 +++------- .../sql/streaming/FileStreamSourceSuite.scala | 55 +------------------ .../spark/sql/streaming/StreamTest.scala | 42 +++----------- 6 files changed, 24 insertions(+), 129 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 872c0c09a064f..84bcdae4238fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -73,17 +73,8 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) - // Whether file-based streaming sources will ignore column of all null values or empty - // array during JSON schema inference. 
- val dropFieldIfAllNull = { - val streamingSchemaInference = - parameters.get("streamingSchemaInference").map(_.toBoolean).getOrElse(false) - val _dropFieldIfAllNull = - parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) - - // We could set true at `dropFieldIfAllNull` iff `spark.sql.streaming.schemaInference` enabled - streamingSchemaInference && _dropFieldIfAllNull - } + // Whether to ignore column of all null values or empty array during JSON schema inference + val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 61753bf241a33..499c79503422b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -330,7 +330,6 @@ class JacksonParser( values += fieldConverter.apply(parser) } // Canonicalize arrays; an array is null if all its elements are null - // TODO: Reconsider this if (options.dropFieldIfAllNull && values.forall(_ == null)) { null } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index d7ecb78bd7e84..3b04510d29695 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -52,11 +52,10 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val sqlConf = sparkSession.sessionState.conf val parsedOptions = new JSONOptions( - options + ("streamingSchemaInference" -> sqlConf.streamingSchemaInference.toString), - sqlConf.sessionLocalTimeZone, - sqlConf.columnNameOfCorruptRecord) + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) JsonDataSource(parsedOptions).inferSchema( sparkSession, files, parsedOptions) } @@ -100,11 +99,10 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val sqlConf = sparkSession.sessionState.conf val parsedOptions = new JSONOptions( - options + ("streamingSchemaInference" -> sqlConf.streamingSchemaInference.toString), - sqlConf.sessionLocalTimeZone, - sqlConf.columnNameOfCorruptRecord) + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) val actualSchema = StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 3eed0745c8075..8c016abc5b643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ 
-19,13 +19,15 @@ package org.apache.spark.sql.execution.streaming import java.net.URI +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.datasources.{DataSource, InMemoryFileIndex, LogicalRelation} -import org.apache.spark.sql.types.{NullType, StructType} +import org.apache.spark.sql.types.StructType /** * A very simple source that reads files from the given directory as they appear. @@ -34,17 +36,13 @@ class FileStreamSource( sparkSession: SparkSession, path: String, fileFormatClassName: String, - initialSchema: StructType, + override val schema: StructType, partitionColumns: Seq[String], metadataPath: String, options: Map[String, String]) extends Source with Logging { import FileStreamSource._ - private var _schema: Option[StructType] = None - - override def schema: StructType = _schema.getOrElse(initialSchema) - private val sourceOptions = new FileStreamOptions(options) private val hadoopConf = sparkSession.sessionState.newHadoopConf() @@ -165,28 +163,16 @@ class FileStreamSource( val files = metadataLog.get(Some(startOffset + 1), Some(endOffset)).flatMap(_._2) logInfo(s"Processing ${files.length} files from ${startOffset + 1}:$endOffset") logTrace(s"Files are:\n\t" + files.mkString("\n\t")) - - // If the current schema has `NullType`s, we will trigger schema inference again in - // the current batch. - val doInferSchemaInCurrentBatch = schema.existsRecursively(_.acceptsType(NullType)) - val newDataSource = DataSource( sparkSession, paths = files.map(f => new Path(new URI(f.path)).toString), - userSpecifiedSchema = if (!doInferSchemaInCurrentBatch) Some(schema) else None, + userSpecifiedSchema = Some(schema), partitionColumns = partitionColumns, className = fileFormatClassName, options = optionsWithPartitionBasePath) - - val rel = newDataSource.resolveRelation(checkFilesExist = false) - // If schema inference triggered in the current batch, replaces the current `schema` - // with the inferred one. 
- if (doInferSchemaInCurrentBatch) { - _schema = Some(rel.schema) - } - - Dataset.ofRows(sparkSession, LogicalRelation(rel, isStreaming = true)) + Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( + checkFilesExist = false), isStreaming = true)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index ce25d9251f3c1..d4bd9c7987f2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -24,11 +24,11 @@ import scala.util.Random import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.scalatest.PrivateMethodTester +import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap} import org.apache.spark.sql.internal.SQLConf @@ -624,56 +624,6 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } - test("SPARK-23772 Ignore column of all null values or empty array during JSON schema inference") { - withTempDirs { case (src, tmp) => - withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { - - // Add a file so that we can infer its schema - stringToFile(new File(src, "existing"), - "{'c0': 1, 'c1': null, 'c2': []}\n{'c0': 2, 'c1': null, 'c2': null}") - - val fileStream = createFileStream( - "json", src.getCanonicalPath, options = Map("dropFieldIfAllNull" -> "true")) - - // FileStreamSource should infer the column "k" - assert(fileStream.schema === StructType( - StructField("c0", LongType) :: - StructField("c1", NullType) :: - StructField("c2", ArrayType(NullType)) :: Nil)) - - testStream(fileStream)( - - // Should not pick up column v in the file added before start - AddTextFileData("{'c0': 3, 'c1': 3.8, 'c2': [null, null]}", src, tmp), - CheckSchema(StructType( - StructField("c0", LongType) :: - StructField("c1", DoubleType) :: - StructField("c2", ArrayType(NullType)) :: Nil)), - // - // CheckAnswer(Row(1, null, Array()) :: Row(2, null, null) :: - // Row(3, 3.8, Array(null, null)) :: Nil: _*), - // - // Canonicalize arrays; an array is null if all its elements are null - CheckAnswer(Row(1, null, null) :: Row(2, null, null) :: Row(3, 3.8, null) - :: Nil: _*), - - // Should read data in column k, and ignore v - AddTextFileData("{'c0': 4, 'c1': 1.1, 'c2': [1, 2, 3]}", src, tmp), - CheckSchema(StructType( - StructField("c0", LongType) :: - StructField("c1", DoubleType) :: - StructField("c2", ArrayType(LongType)) :: Nil)), - // - // CheckAnswer(Row(1, null, Array()) :: Row(2, null, null) :: - // Row(3, 3.8, Array(null, null)) :: Row(4, 1.1, Array(1, 2, 3)) :: Nil: _*)) - // - // Canonicalize arrays; an array is null if all its elements are null - CheckAnswer(Row(1, null, null) :: Row(2, null, null) :: - Row(3, 3.8, null) :: Row(4, 1.1, Array(1, 2, 3)) :: Nil: _*)) - } - } - } - // =============== ORC file stream tests ================ test("read from orc files") { @@ -1165,8 +1115,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val df = spark.readStream.format("text").load(src.getCanonicalPath).map(_ + "-x") // Test `explain` not throwing errors - val explainCmd = 
ExplainCommand(df.queryExecution.logical, extended = false) - spark.sessionState.executePlan(explainCmd).executedPlan + df.explain() val q = df.writeStream.queryName("file_explain").format("memory").start() .asInstanceOf[StreamingQueryWrapper] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 65dc3e727e298..9d139a927bea5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -44,9 +44,9 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -203,8 +203,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" } - case class CheckSchema(expectedSchema: StructType) extends StreamAction with StreamMustBeRunning - case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) extends StreamAction with StreamMustBeRunning { override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" @@ -457,7 +455,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var lastFetchedMemorySinkLastBatchId: Long = -1 - def waitAllDataProcessed() = { + def fetchStreamAnswer( + currentStream: StreamExecution, + lastOnly: Boolean = false, + sinceLastFetchOnly: Boolean = false) = { + verify( + !(lastOnly && sinceLastFetchOnly), "both lastOnly and sinceLastFetchOnly cannot be true") verify(currentStream != null, "stream not running") // Block until all data added has been processed for all the source @@ -470,16 +473,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } } - } - - def fetchStreamAnswer( - currentStream: StreamExecution, - lastOnly: Boolean = false, - sinceLastFetchOnly: Boolean = false) = { - verify( - !(lastOnly && sinceLastFetchOnly), "both lastOnly and sinceLastFetchOnly cannot be true") - - waitAllDataProcessed() val lastExecution = currentStream.lastExecution if (currentStream.isInstanceOf[MicroBatchExecution] && lastExecution != null) { @@ -753,27 +746,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { error => failTest(error) } - - case CheckSchema(expectedSchema) => - waitAllDataProcessed() - val resultSchema = currentStream.lastExecution.analyzed.schema - if (expectedSchema != resultSchema) { - failTest( - s""" - |== Results == - |${ - sideBySide( - s""" - |== Correct Schema == - |${expectedSchema.simpleString} - """.stripMargin, - s""" - |== Spark Result Schema == - |${resultSchema.simpleString} - """.stripMargin).mkString("\n") - } - """.stripMargin) - } } pos += 1 } From 6c4592d2b8f008adfed79a31a8373641d4f4f550 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 23 May 2018 00:00:08 +0900 Subject: [PATCH 5/9] Fix --- 
.../spark/sql/catalyst/json/JSONOptions.scala | 2 +- .../sql/catalyst/json/JacksonParser.scala | 2 +- .../apache/spark/sql/DataFrameReader.scala | 2 ++ .../execution/datasources/DataSource.scala | 2 +- .../datasources/json/JsonInferSchema.scala | 16 +++++++++++---- .../datasources/json/JsonSuite.scala | 20 +++++++++++++++++++ 6 files changed, 37 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 84bcdae4238fb..f2a48ccf4526a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -73,7 +73,7 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) - // Whether to ignore column of all null values or empty array during JSON schema inference + // Whether to ignore column of all null values or empty array/struct during schema inference val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) val timeZone: TimeZone = DateTimeUtils.getTimeZone( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 499c79503422b..9fd6c434326c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -79,7 +79,7 @@ class JacksonParser( val array = convertArray(parser, elementConverter) // Here, as we support reading top level JSON arrays and take every element // in such an array as a row, this case is possible. - if (array.numElements() == 0) { + if (array == null || array.numElements() == 0) { Nil } else { array.toArray[InternalRow](schema).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 53f44888ebaff..ff066629649b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -379,6 +379,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * that should be used for parsing. *
   * <li>`samplingRatio` (default is 1.0): defines fraction of input JSON objects used
   * for schema inferring.</li>
+   * <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
+   * empty array/struct during schema inference.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index fed857d26cf0a..f16d824201e77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -266,7 +266,7 @@ case class DataSource( sparkSession = sparkSession, path = path, fileFormatClassName = className, - initialSchema = sourceInfo.schema, + schema = sourceInfo.schema, partitionColumns = sourceInfo.partitionColumns, metadataPath = metadataPath, options = caseInsensitiveOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index ec441a5ba1a46..35213dd9b596a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -126,9 +126,13 @@ private[sql] object JsonInferSchema { nullable = true) } val fields: Array[StructField] = builder.result() - // Note: other code relies on this sorting for correctness, so don't remove it! - java.util.Arrays.sort(fields, structFieldComparator) - StructType(fields) + if (configOptions.dropFieldIfAllNull && fields.isEmpty) { + NullType + } else { + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(fields, structFieldComparator) + StructType(fields) + } case START_ARRAY => // If this JSON array is empty, we use NullType as a placeholder. 
@@ -140,7 +144,11 @@ private[sql] object JsonInferSchema { elementType, inferField(parser, configOptions)) } - ArrayType(elementType) + if (configOptions.dropFieldIfAllNull && elementType == NullType) { + NullType + } else { + ArrayType(elementType) + } case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 0db688fec9a67..47dd85ed2a2bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2408,4 +2408,24 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), Row(badJson)) } + + test("SPARK-23772 ignore column of all null values or empty array during schema inference") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq( + """{"a":null, "b":[null, null], "c":null, "d":[[], [null]], "e":{}}""", + """{"a":null, "b":[null], "c":[], "d": [null, []], "e":{}}""", + """{"a":null, "b":[], "c":[], "d": null, "e":null}""") + .toDS().write.mode("overwrite").text(path) + val df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + val expectedSchema = new StructType() + .add("a", NullType).add("b", NullType).add("c", NullType).add("d", NullType) + .add("e", NullType) + assert(df.schema === expectedSchema) + val nullRow = Row(null, null, null, null, null) + checkAnswer(df, nullRow :: nullRow :: nullRow :: Nil) + } + } } From 907cf384cf6509091572543346ee948d24b561af Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 28 May 2018 22:09:53 +0900 Subject: [PATCH 6/9] Fix --- python/pyspark/sql/readwriter.py | 5 +- .../sql/catalyst/json/JacksonParser.scala | 15 +++--- .../sql/streaming/DataStreamReader.scala | 2 + .../datasources/json/JsonSuite.scala | 51 +++++++++++++++---- 4 files changed, 56 insertions(+), 17 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 448a4732001b5..cd3c0df006f6e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, - encoding=None): + dropFieldIfAllNull=None, encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -246,6 +246,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. :param samplingRatio: defines fraction of input JSON objects used for schema inferring. If None is set, it uses the default value, ``1.0``. + :param dropFieldIfAllNull: whether to ignore column of all null values or empty + array/struct during schema inference. If None is set, it + uses the default value, ``false``. 
>>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 9fd6c434326c5..f4316b90613cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -79,7 +79,7 @@ class JacksonParser( val array = convertArray(parser, elementConverter) // Here, as we support reading top level JSON arrays and take every element // in such an array as a row, this case is possible. - if (array == null || array.numElements() == 0) { + if (array.numElements() == 0) { Nil } else { array.toArray[InternalRow](schema).toSeq @@ -234,6 +234,11 @@ class JacksonParser( case udt: UserDefinedType[_] => makeConverter(udt.sqlType) + case _: NullType if options.dropFieldIfAllNull => + (parser: JsonParser) => parseJsonToken[Null](parser, dataType) { + case _ => null + } + case _ => (parser: JsonParser) => // Here, we pass empty `PartialFunction` so that this case can be @@ -329,12 +334,8 @@ class JacksonParser( while (nextUntil(parser, JsonToken.END_ARRAY)) { values += fieldConverter.apply(parser) } - // Canonicalize arrays; an array is null if all its elements are null - if (options.dropFieldIfAllNull && values.forall(_ == null)) { - null - } else { - new GenericArrayData(values.toArray) - } + + new GenericArrayData(values.toArray) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ae93965bc50ed..ef8dc3a325a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -270,6 +270,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * per file *
   * <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
   * that should be used for parsing.</li>
+   * <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
+   * empty array/struct during schema inference.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 47dd85ed2a2bd..1fa3bed91cf53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2412,20 +2412,53 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-23772 ignore column of all null values or empty array during schema inference") { withTempPath { tempDir => val path = tempDir.getAbsolutePath + + // primitive types Seq( - """{"a":null, "b":[null, null], "c":null, "d":[[], [null]], "e":{}}""", - """{"a":null, "b":[null], "c":[], "d": [null, []], "e":{}}""", - """{"a":null, "b":[], "c":[], "d": null, "e":null}""") + """{"a":null, "b":1, "c":3.0}""", + """{"a":null, "b":null, "c":"string"}""", + """{"a":null, "b":null, "c":null}""") + .toDS().write.text(path) + var df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + var expectedSchema = new StructType() + .add("a", NullType).add("b", LongType).add("c", StringType) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(null, 1, "3.0") :: Row(null, null, "string") :: Row(null, null, null) + :: Nil) + + // arrays + Seq( + """{"a":[2, 1], "b":[null, null], "c":null, "d":[[], [null]]}""", + """{"a":[null], "b":[null], "c":[], "d": [null, []]}""", + """{"a":null, "b":null, "c":[], "d": null}""") .toDS().write.mode("overwrite").text(path) - val df = spark.read.format("json") + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", ArrayType(LongType)).add("b", NullType).add("c", NullType).add("d", NullType) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Array(2, 1), null, null, null) :: Row(Array(null), null, null, null) :: + Row(null, null, null, null) :: Nil) + + // structs + Seq( + """{"a": {"a1": 1, "a2":"string"}, "b":{}}""", + """{"a": {"a1": 2, "a2":null}, "b":{}}""", + """{"a": null, "b":null}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") .option("dropFieldIfAllNull", true) .load(path) - val expectedSchema = new StructType() - .add("a", NullType).add("b", NullType).add("c", NullType).add("d", NullType) - .add("e", NullType) + expectedSchema = new StructType() + .add("a", StructType(StructField("a1", LongType) :: StructField("a2", StringType) + :: Nil)) + .add("b", NullType) assert(df.schema === expectedSchema) - val nullRow = Row(null, null, null, null, null) - checkAnswer(df, nullRow :: nullRow :: nullRow :: Nil) + checkAnswer(df, Row(Row(1, "string"), null) :: Row(Row(2, null), null) :: + Row(null, null) :: Nil) } } } From 58054ef61f61a999117ec8617eed34e446ddb078 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 14 Jun 2018 08:46:19 +0900 Subject: [PATCH 7/9] Add tests --- .../execution/datasources/json/JsonSuite.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 1fa3bed91cf53..72fddeaa1a0a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2430,24 +2430,26 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // arrays Seq( - """{"a":[2, 1], "b":[null, null], "c":null, "d":[[], [null]]}""", - """{"a":[null], "b":[null], "c":[], "d": [null, []]}""", - """{"a":null, "b":null, "c":[], "d": null}""") + """{"a":[2, 1], "b":[null, null], "c":null, "d":[[], [null]], "e":[[], null, [[]]]}""", + """{"a":[null], "b":[null], "c":[], "d":[null, []], "e":null}""", + """{"a":null, "b":null, "c":[], "d":null, "e":[null, [], null]}""") .toDS().write.mode("overwrite").text(path) df = spark.read.format("json") .option("dropFieldIfAllNull", true) .load(path) expectedSchema = new StructType() .add("a", ArrayType(LongType)).add("b", NullType).add("c", NullType).add("d", NullType) + .add("e", NullType) assert(df.schema === expectedSchema) - checkAnswer(df, Row(Array(2, 1), null, null, null) :: Row(Array(null), null, null, null) :: - Row(null, null, null, null) :: Nil) + checkAnswer(df, Row(Array(2, 1), null, null, null, null) :: + Row(Array(null), null, null, null, null) :: + Row(null, null, null, null, null) :: Nil) // structs Seq( - """{"a": {"a1": 1, "a2":"string"}, "b":{}}""", - """{"a": {"a1": 2, "a2":null}, "b":{}}""", - """{"a": null, "b":null}""") + """{"a":{"a1": 1, "a2":"string"}, "b":{}}""", + """{"a":{"a1": 2, "a2":null}, "b":{}}""", + """{"a":null, "b":null}""") .toDS().write.mode("overwrite").text(path) df = spark.read.format("json") .option("dropFieldIfAllNull", true) From 22e0d9f12e4b08a4337c61371cf4ff795a2752b2 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 14 Jun 2018 09:25:04 +0900 Subject: [PATCH 8/9] Brush up code --- .../datasources/json/JsonInferSchema.scala | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index 35213dd9b596a..2f165a80884fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -126,13 +126,9 @@ private[sql] object JsonInferSchema { nullable = true) } val fields: Array[StructField] = builder.result() - if (configOptions.dropFieldIfAllNull && fields.isEmpty) { - NullType - } else { - // Note: other code relies on this sorting for correctness, so don't remove it! - java.util.Arrays.sort(fields, structFieldComparator) - StructType(fields) - } + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(fields, structFieldComparator) + StructType(fields) case START_ARRAY => // If this JSON array is empty, we use NullType as a placeholder. @@ -144,11 +140,7 @@ private[sql] object JsonInferSchema { elementType, inferField(parser, configOptions)) } - if (configOptions.dropFieldIfAllNull && elementType == NullType) { - NullType - } else { - ArrayType(elementType) - } + ArrayType(elementType) case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType @@ -184,16 +176,25 @@ private[sql] object JsonInferSchema { } /** - * Convert NullType to StringType and remove StructTypes with no fields + * Canonicalize inferred types, e.g., convert NullType to StringType and remove StructTypes + * with no fields. 
*/ private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match { case at @ ArrayType(elementType, _) => - for { + val canonicalizeArrayOption = for { canonicalType <- canonicalizeType(elementType, options) } yield { at.copy(canonicalType) } + canonicalizeArrayOption.map { array => + if (options.dropFieldIfAllNull && array.elementType == NullType) { + NullType + } else { + array + } + } + case StructType(fields) => val canonicalFields: Array[StructField] = for { field <- fields @@ -205,6 +206,8 @@ private[sql] object JsonInferSchema { if (canonicalFields.length > 0) { Some(StructType(canonicalFields)) + } else if (options.dropFieldIfAllNull) { + Some(NullType) } else { // per SPARK-8093: empty structs should be deleted None From 4544433760bd70cff41aa8e8bb718e6de0e3b877 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 17 Jun 2018 14:55:07 -0700 Subject: [PATCH 9/9] update impl --- .../sql/catalyst/json/JacksonParser.scala | 5 -- .../datasources/json/JsonInferSchema.scala | 47 +++++++------------ .../datasources/json/JsonSuite.scala | 18 +++---- 3 files changed, 24 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index f4316b90613cc..a5a4a13eb608b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -234,11 +234,6 @@ class JacksonParser( case udt: UserDefinedType[_] => makeConverter(udt.sqlType) - case _: NullType if options.dropFieldIfAllNull => - (parser: JsonParser) => parseJsonToken[Null](parser, dataType) { - case _ => null - } - case _ => (parser: JsonParser) => // Here, we pass empty `PartialFunction` so that this case can be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index 2f165a80884fc..97ed1dc35c97c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -176,44 +176,33 @@ private[sql] object JsonInferSchema { } /** - * Canonicalize inferred types, e.g., convert NullType to StringType and remove StructTypes - * with no fields. + * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields, + * drops NullTypes or converts them to StringType based on provided options. 
*/ private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match { - case at @ ArrayType(elementType, _) => - val canonicalizeArrayOption = for { - canonicalType <- canonicalizeType(elementType, options) - } yield { - at.copy(canonicalType) - } - - canonicalizeArrayOption.map { array => - if (options.dropFieldIfAllNull && array.elementType == NullType) { - NullType - } else { - array - } - } + case at: ArrayType => + canonicalizeType(at.elementType, options) + .map(t => at.copy(elementType = t)) case StructType(fields) => - val canonicalFields: Array[StructField] = for { - field <- fields - if field.name.length > 0 - canonicalType <- canonicalizeType(field.dataType, options) - } yield { - field.copy(dataType = canonicalType) + val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f => + canonicalizeType(f.dataType, options) + .map(t => f.copy(dataType = t)) } - - if (canonicalFields.length > 0) { - Some(StructType(canonicalFields)) - } else if (options.dropFieldIfAllNull) { - Some(NullType) + // SPARK-8093: empty structs should be deleted + if (canonicalFields.isEmpty) { + None } else { - // per SPARK-8093: empty structs should be deleted + Some(StructType(canonicalFields)) + } + + case NullType => + if (options.dropFieldIfAllNull) { None + } else { + Some(StringType) } - case NullType if !options.dropFieldIfAllNull => Some(StringType) case other => Some(other) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 72fddeaa1a0a9..0e4523bfe088c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2423,10 +2423,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .option("dropFieldIfAllNull", true) .load(path) var expectedSchema = new StructType() - .add("a", NullType).add("b", LongType).add("c", StringType) + .add("b", LongType).add("c", StringType) assert(df.schema === expectedSchema) - checkAnswer(df, Row(null, 1, "3.0") :: Row(null, null, "string") :: Row(null, null, null) - :: Nil) + checkAnswer(df, Row(1, "3.0") :: Row(null, "string") :: Row(null, null) :: Nil) // arrays Seq( @@ -2438,17 +2437,14 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .option("dropFieldIfAllNull", true) .load(path) expectedSchema = new StructType() - .add("a", ArrayType(LongType)).add("b", NullType).add("c", NullType).add("d", NullType) - .add("e", NullType) + .add("a", ArrayType(LongType)) assert(df.schema === expectedSchema) - checkAnswer(df, Row(Array(2, 1), null, null, null, null) :: - Row(Array(null), null, null, null, null) :: - Row(null, null, null, null, null) :: Nil) + checkAnswer(df, Row(Array(2, 1)) :: Row(Array(null)) :: Row(null) :: Nil) // structs Seq( """{"a":{"a1": 1, "a2":"string"}, "b":{}}""", - """{"a":{"a1": 2, "a2":null}, "b":{}}""", + """{"a":{"a1": 2, "a2":null}, "b":{"b1":[null]}}""", """{"a":null, "b":null}""") .toDS().write.mode("overwrite").text(path) df = spark.read.format("json") @@ -2457,10 +2453,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { expectedSchema = new StructType() .add("a", StructType(StructField("a1", LongType) :: StructField("a2", StringType) :: Nil)) - .add("b", NullType) assert(df.schema === expectedSchema) - checkAnswer(df, Row(Row(1, "string"), 
null) :: Row(Row(2, null), null) :: - Row(null, null) :: Nil) + checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil) } } }
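
Note on the end state of the series: after [PATCH 9/9], `dropFieldIfAllNull` is a plain JSON reader option honored during schema inference — `JsonInferSchema.canonicalizeType` drops `NullType` fields when the option is set and converts them to `StringType` otherwise — and the earlier special casing in `JacksonParser` and in the streaming source is gone. Below is a minimal, self-contained sketch of exercising the option from user code. The option name and the expected inferred schema follow the JsonSuite test above; the object name, output path, and sample records are illustrative only and not part of the patch.

import org.apache.spark.sql.SparkSession

// Illustrative driver, not part of the patch series; the path and data are made up.
object DropFieldIfAllNullSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("dropFieldIfAllNull sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val path = "/tmp/drop-field-if-all-null-sketch"

    // "a" is null in every record; "b" and "c" carry a real value in at least one record.
    Seq(
      """{"a":null, "b":1,    "c":"x"}""",
      """{"a":null, "b":null, "c":"y"}""",
      """{"a":null, "b":null, "c":null}""")
      .toDS().write.mode("overwrite").text(path)

    // With the option enabled, schema inference should drop the all-null column "a",
    // leaving `b` as long and `c` as string (mirroring the JsonSuite expectations).
    val df = spark.read
      .option("dropFieldIfAllNull", "true")
      .json(path)

    df.printSchema()
    df.show()

    spark.stop()
  }
}

Without the option, the same read keeps `a` as a string column of nulls, since `canonicalizeType` otherwise converts `NullType` to `StringType`.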