@@ -58,7 +58,10 @@ class JacksonParser(
   private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length))
 
   private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord)
-  corruptFieldIndex.foreach(idx => require(schema(idx).dataType == StringType))
+  corruptFieldIndex.foreach { corrFieldIndex =>
+    require(schema(corrFieldIndex).dataType == StringType)
+    require(schema(corrFieldIndex).nullable)
+  }
 
   @transient
   private[this] var isWarningPrinted: Boolean = false
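A minimal sketch of a schema that satisfies the strengthened checks above, assuming the default corrupt-record column name _corrupt_record and a single data field a: the corrupt-record column must be declared as a nullable StringType.

import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical user schema: the corrupt-record column is a nullable string,
// so both require(...) calls above pass.
val schema = StructType(
  StructField("_corrupt_record", StringType, nullable = true) ::
  StructField("a", StringType, nullable = true) :: Nil)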
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.jdbc._
 import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -365,6 +365,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         createParser)
     }
 
+    // Check the corrupt-record field requirement here to throw an exception on the driver side
+    schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+      val f = schema(corruptFieldIndex)
+      if (f.dataType != StringType || !f.nullable) {
+        throw new AnalysisException(
+          "The field for corrupt records must be string type and nullable")
+      }
+    }
+
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
       val parser = new JacksonParser(schema, parsedOptions)
       iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
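A usage sketch of what this driver-side check guards in the json(Dataset[String]) path, assuming an active SparkSession named spark; the column name _unparsed and the sample records are made up. A nullable StringType corrupt-record column passes, while a non-string or non-nullable column now fails immediately with the AnalysisException above instead of failing later inside the executors.

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import spark.implicits._

// Hypothetical input: the second record is malformed JSON
val jsonDS = Seq("""{"a": 1}""", """{"a":""").toDS()

// Passes the new check: the corrupt-record column is a nullable string
val df = spark.read
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_unparsed")
  .schema(StructType(
    StructField("_unparsed", StringType, nullable = true) ::
    StructField("a", IntegerType, nullable = true) :: Nil))
  .json(jsonDS)

// The malformed record surfaces with its raw text in _unparsed and a null value for a
df.show(false)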
@@ -22,13 +22,13 @@ import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions}
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
 class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
@@ -102,6 +102,15 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
+    // Check the corrupt-record field requirement here to throw an exception on the driver side
+    dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+      val f = dataSchema(corruptFieldIndex)
+      if (f.dataType != StringType || !f.nullable) {
+        throw new AnalysisException(
+          "The field for corrupt records must be string type and nullable")
+      }
+    }
+
     (file: PartitionedFile) => {
       val parser = new JacksonParser(requiredSchema, parsedOptions)
       JsonDataSource(parsedOptions).readFile(
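For the file-based path, the same check now runs in buildReader before any tasks are launched. A sketch under assumptions (a SparkSession named spark, a placeholder path /tmp/records.json, the default corrupt-record column name): declaring the corrupt-record column as IntegerType makes the read fail on the driver with the message above, which is also what the new test below asserts.

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Hypothetical schema with a non-string corrupt-record column
val badSchema = StructType(
  StructField("_corrupt_record", IntegerType, nullable = true) ::
  StructField("a", StringType, nullable = true) :: Nil)

try {
  spark.read
    .option("mode", "PERMISSIVE")
    .option("columnNameOfCorruptRecord", "_corrupt_record")
    .schema(badSchema)
    .json("/tmp/records.json")  // path is a placeholder
    .collect()
} catch {
  case e: AnalysisException =>
    // Rejected on the driver before the scan starts
    assert(e.getMessage.startsWith("The field for corrupt records must be string type and nullable"))
}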
@@ -1944,4 +1944,35 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode"))
     }
   }
+
+  test("Throw an exception if a `columnNameOfCorruptRecord` field violates requirements") {
+    val columnNameOfCorruptRecord = "_unparsed"
+    val schema = StructType(
+      StructField(columnNameOfCorruptRecord, IntegerType, true) ::
+      StructField("a", StringType, true) ::
+      StructField("b", StringType, true) ::
+      StructField("c", StringType, true) :: Nil)
+    val errMsg = intercept[AnalysisException] {
+      spark.read
+        .option("mode", "PERMISSIVE")
+        .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+        .schema(schema)
+        .json(corruptRecords)
+    }.getMessage
+    assert(errMsg.startsWith("The field for corrupt records must be string type and nullable"))
+
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      corruptRecords.toDF("value").write.text(path)
+      val errMsg = intercept[AnalysisException] {
+        spark.read
+          .option("mode", "PERMISSIVE")
+          .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+          .schema(schema)
+          .json(path)
+          .collect
+      }.getMessage
+      assert(errMsg.startsWith("The field for corrupt records must be string type and nullable"))
+    }
+  }
 }