diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 1a1954281cd06..9954ae4c44512 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -73,8 +73,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def schema(schema: StructType): DataFrameReader = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] - this.userSpecifiedSchema = Option(replaced) + if (schema != null) { + val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] + this.userSpecifiedSchema = Option(replaced) + } this } @@ -90,10 +92,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 2.3.0 */ def schema(schemaString: String): DataFrameReader = { - val rawSchema = StructType.fromDDL(schemaString) - val schema = CharVarcharUtils.failIfHasCharVarchar(rawSchema).asInstanceOf[StructType] - this.userSpecifiedSchema = Option(schema) - this + schema(StructType.fromDDL(schemaString)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 2ed2487c83b01..a6913fab97a40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -64,8 +64,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * @since 2.0.0 */ def schema(schema: StructType): DataStreamReader = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] - this.userSpecifiedSchema = Option(replaced) + if (schema != null) { + val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] + this.userSpecifiedSchema = Option(replaced) + } this } @@ -77,10 +79,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * @since 2.3.0 */ def schema(schemaString: String): DataStreamReader = { - val rawSchema = StructType.fromDDL(schemaString) - val schema = CharVarcharUtils.failIfHasCharVarchar(rawSchema).asInstanceOf[StructType] - this.userSpecifiedSchema = Option(schema) - this + schema(StructType.fromDDL(schemaString)) } /**