diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 64a00c0b8bfcd..7faa8662ebf03 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -1155,7 +1155,7 @@ def _test():
     globs['sqlContext'] = SQLContext.getOrCreate(spark.sparkContext)
     globs['sdf'] = \
         spark.readStream.format('text').load('python/test_support/sql/streaming')
-    globs['sdf_schema'] = StructType([StructField("data", StringType(), False)])
+    globs['sdf_schema'] = StructType([StructField("data", StringType(), True)])
     globs['df'] = \
         globs['spark'].readStream.format('text').load('python/test_support/sql/streaming')
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 80a7d4efe4e52..decfee5b713d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1140,6 +1140,15 @@ object SQLConf {
     .timeConf(TimeUnit.MILLISECONDS)
     .createWithDefault(TimeUnit.MINUTES.toMillis(10)) // 10 minutes

+  val FILE_SOURCE_SCHEMA_FORCE_NULLABLE =
+    buildConf("spark.sql.streaming.fileSource.schema.forceNullable")
+      .internal()
+      .doc("When true, force the schema of streaming file source to be nullable (including all " +
+        "the fields). Otherwise, the schema might not be compatible with actual data, which " +
+        "leads to corruptions.")
+      .booleanConf
+      .createWithDefault(true)
+
   val STREAMING_SCHEMA_INFERENCE = buildConf("spark.sql.streaming.schemaInference")
     .internal()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 04ae528a1f6b3..0ccd99c88cf28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -254,9 +254,12 @@ case class DataSource(
           checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false)
           createInMemoryFileIndex(globbedPaths)
         })
+        val forceNullable =
+          sparkSession.sessionState.conf.getConf(SQLConf.FILE_SOURCE_SCHEMA_FORCE_NULLABLE)
+        val sourceDataSchema = if (forceNullable) dataSchema.asNullable else dataSchema
         SourceInfo(
           s"FileSource[$path]",
-          StructType(dataSchema ++ partitionSchema),
+          StructType(sourceDataSchema ++ partitionSchema),
          partitionSchema.fieldNames)

       case _ =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 72f893845172d..f3f03715ee83a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -1577,6 +1577,25 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
       )
     }
   }
+
+  test("SPARK-28651: force streaming file source to be nullable") {
+    withTempDir { temp =>
+      val schema = StructType(Seq(StructField("foo", LongType, false)))
+      val nullableSchema = StructType(Seq(StructField("foo", LongType, true)))
+      val streamingSchema = spark.readStream.schema(schema).json(temp.getCanonicalPath).schema
+      assert(nullableSchema === streamingSchema)
+
+      // Verify we have the same behavior as batch DataFrame.
+      val batchSchema = spark.read.schema(schema).json(temp.getCanonicalPath).schema
+      assert(batchSchema === streamingSchema)
+
+      // Verify the flag works
+      withSQLConf(SQLConf.FILE_SOURCE_SCHEMA_FORCE_NULLABLE.key -> "false") {
+        val streamingSchema = spark.readStream.schema(schema).json(temp.getCanonicalPath).schema
+        assert(schema === streamingSchema)
+      }
+    }
+  }
 }

 class FileStreamSourceStressTestSuite extends FileStreamSourceTest {