4 changes: 4 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -35,6 +35,7 @@ private[spark] object SQLConf {
   val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString"
   val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata"
   val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec"
+  val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord"

   // This is only used for the thriftserver
   val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
@@ -131,6 +132,9 @@ private[sql] trait SQLConf {
   private[spark] def inMemoryPartitionPruning: Boolean =
     getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean

+  private[spark] def columnNameOfCorruptRecord: String =
+    getConf(COLUMN_NAME_OF_CORRUPT_RECORD, "_corrupt_record")
+
   /** ********************** SQLConf functionality methods ************ */

   /** Set Spark SQL configuration properties. */
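
For context, a minimal usage sketch of the new setting (not part of this diff; it assumes an existing SQLContext named sqlc and an RDD[String] of JSON documents named jsonStrings, both hypothetical):

// Sketch only: override the default "_corrupt_record" column name, then load the data.
sqlc.setConf("spark.sql.columnNameOfCorruptRecord", "_unparsed")
val events = sqlc.jsonRDD(jsonStrings)
events.registerTempTable("events")
// Records that fail to parse have all data columns set to null and keep their
// raw text in the configured column.
val corrupt = sqlc.sql("SELECT _unparsed FROM events WHERE _unparsed IS NOT NULL")
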
14 changes: 10 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -195,9 +195,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   @Experimental
   def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = {
+    val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord
     val appliedSchema =
-      Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0)))
-    val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema)
+      Option(schema).getOrElse(
+        JsonRDD.nullTypeToStringType(
+          JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
+    val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
     applySchema(rowRDD, appliedSchema)
   }

@@ -206,8 +209,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   @Experimental
   def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = {
-    val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio))
-    val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema)
+    val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord
+    val appliedSchema =
+      JsonRDD.nullTypeToStringType(
+        JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
+    val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
     applySchema(rowRDD, appliedSchema)
   }
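
A hedged reading of the schema-taking overload above (my inference from the code, not stated in the diff): when the caller supplies a schema, unparseable records only surface if that schema declares the corrupt-record column; otherwise they become all-null rows. A sketch under that assumption, reusing the hypothetical sqlc and jsonStrings from the previous example:

import org.apache.spark.sql.catalyst.types.{StringType, StructField, StructType}

// Declare the corrupt-record column (default name "_corrupt_record") explicitly
// so malformed documents keep their raw text instead of collapsing to null rows.
val schemaWithCorruptColumn = StructType(
  StructField("_corrupt_record", StringType, true) ::
  StructField("a", StringType, true) :: Nil)
val eventsWithSchema = sqlc.jsonRDD(jsonStrings, schemaWithCorruptColumn)
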
@@ -148,8 +148,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
    * It goes through the entire dataset once to determine the schema.
    */
   def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = {
-    val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))
-    val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
+    val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord
+    val appliedScalaSchema =
+      JsonRDD.nullTypeToStringType(
+        JsonRDD.inferSchema(json.rdd, 1.0, columnNameOfCorruptJsonRecord))
+    val scalaRowRDD =
+      JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord)
     val logicalPlan =
       LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext)
     new JavaSchemaRDD(sqlContext, logicalPlan)
@@ -162,10 +166,14 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
    */
   @Experimental
   def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = {
+    val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord
     val appliedScalaSchema =
       Option(asScalaDataType(schema)).getOrElse(
-        JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType]
-    val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
+        JsonRDD.nullTypeToStringType(
+          JsonRDD.inferSchema(
+            json.rdd, 1.0, columnNameOfCorruptJsonRecord))).asInstanceOf[SStructType]
+    val scalaRowRDD = JsonRDD.jsonStringToRow(
+      json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord)
     val logicalPlan =
       LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext)
     new JavaSchemaRDD(sqlContext, logicalPlan)
30 changes: 20 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -21,6 +21,7 @@ import scala.collection.Map
 import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
 import scala.math.BigDecimal

+import com.fasterxml.jackson.core.JsonProcessingException
 import com.fasterxml.jackson.databind.ObjectMapper

 import org.apache.spark.rdd.RDD
@@ -34,16 +35,19 @@ private[sql] object JsonRDD extends Logging {

   private[sql] def jsonStringToRow(
       json: RDD[String],
-      schema: StructType): RDD[Row] = {
-    parseJson(json).map(parsed => asRow(parsed, schema))
+      schema: StructType,
+      columnNameOfCorruptRecords: String): RDD[Row] = {
+    parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema))
   }

   private[sql] def inferSchema(
       json: RDD[String],
-      samplingRatio: Double = 1.0): StructType = {
+      samplingRatio: Double = 1.0,
+      columnNameOfCorruptRecords: String): StructType = {
     require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
     val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1)
-    val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _)
+    val allKeys =
+      parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _)
     createSchema(allKeys)
   }
@@ -273,7 +277,9 @@
     case atom => atom
   }

-  private def parseJson(json: RDD[String]): RDD[Map[String, Any]] = {
+  private def parseJson(
+      json: RDD[String],
+      columnNameOfCorruptRecords: String): RDD[Map[String, Any]] = {
     // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72],
     // ObjectMapper will not return BigDecimal when
     // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled
@@ -288,12 +294,16 @@
       // For example: for {"key": 1, "key":2}, we will get "key"->2.
       val mapper = new ObjectMapper()
       iter.flatMap { record =>
-        val parsed = mapper.readValue(record, classOf[Object]) match {
-          case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil
-          case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]]
-        }
+        try {
+          val parsed = mapper.readValue(record, classOf[Object]) match {
+            case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil
+            case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]]
+          }

-        parsed
+          parsed
+        } catch {
+          case e: JsonProcessingException => Map(columnNameOfCorruptRecords -> record) :: Nil
+        }
       }
     })
   }
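
The parsing fallback added above can be illustrated in isolation. A standalone sketch (assumes only Jackson on the classpath; the object and method names are made up for illustration, and the list/scalar case is simplified):

import scala.collection.JavaConverters._

import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper

object CorruptRecordSketch {
  // Parse one JSON object; on failure, keep the raw text under `corruptColumn`.
  def parse(record: String, corruptColumn: String): Map[String, Any] = {
    val mapper = new ObjectMapper()
    try {
      mapper.readValue(record, classOf[Object]) match {
        case m: java.util.Map[_, _] => m.asScala.map { case (k, v) => k.toString -> v }.toMap
        case other => Map("value" -> other)
      }
    } catch {
      case _: JsonProcessingException => Map(corruptColumn -> record)
    }
  }

  def main(args: Array[String]): Unit = {
    println(parse("""{"a": 1}""", "_corrupt_record"))     // Map(a -> 1)
    println(parse("""{"a":1, b:2}""", "_corrupt_record")) // Map(_corrupt_record -> {"a":1, b:2})
  }
}
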
@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
 import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext._

 class JsonSuite extends QueryTest {
@@ -636,7 +638,65 @@ class JsonSuite extends QueryTest {
       ("str_a_1", null, null) ::
       ("str_a_2", null, null) ::
       (null, "str_b_3", null) ::
-      ("str_a_4", "str_b_4", "str_c_4") ::Nil
+      ("str_a_4", "str_b_4", "str_c_4") :: Nil
     )
   }

+  test("Corrupt records") {
+    // Test if we can query corrupt records.
+    val oldColumnNameOfCorruptRecord = TestSQLContext.columnNameOfCorruptRecord
+    TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
+
+    val jsonSchemaRDD = jsonRDD(corruptRecords)
+    jsonSchemaRDD.registerTempTable("jsonTable")
+
+    val schema = StructType(
+      StructField("_unparsed", StringType, true) ::
+      StructField("a", StringType, true) ::
+      StructField("b", StringType, true) ::
+      StructField("c", StringType, true) :: Nil)
+
+    assert(schema === jsonSchemaRDD.schema)
+
+    // In HiveContext, backticks should be used to access columns starting with an underscore.
+    checkAnswer(
+      sql(
+        """
+          |SELECT a, b, c, _unparsed
+          |FROM jsonTable
+        """.stripMargin),
+      (null, null, null, "{") ::
+      (null, null, null, "") ::
+      (null, null, null, """{"a":1, b:2}""") ::
+      (null, null, null, """{"a":{, b:3}""") ::
+      ("str_a_4", "str_b_4", "str_c_4", null) ::
+      (null, null, null, "]") :: Nil
+    )
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT a, b, c
+          |FROM jsonTable
+          |WHERE _unparsed IS NULL
+        """.stripMargin),
+      ("str_a_4", "str_b_4", "str_c_4") :: Nil
+    )
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT _unparsed
+          |FROM jsonTable
+          |WHERE _unparsed IS NOT NULL
+        """.stripMargin),
+      Seq("{") ::
+      Seq("") ::
+      Seq("""{"a":1, b:2}""") ::
+      Seq("""{"a":{, b:3}""") ::
+      Seq("]") :: Nil
+    )
+
+    TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
+  }
 }
@@ -143,4 +143,13 @@ object TestJsonData {
     """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
     """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
     """[]""" :: Nil)
+
+  val corruptRecords =
+    TestSQLContext.sparkContext.parallelize(
+      """{""" ::
+      """""" ::
+      """{"a":1, b:2}""" ::
+      """{"a":{, b:3}""" ::
+      """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
+      """]""" :: Nil)
 }