Skip to content

Commit 46db418

Browse files
committed
Handle JSON arrays in the type of ArrayType(...(ArrayType(StructType))).
1 parent ed1980f commit 46db418

File tree

3 files changed

+96
-29
lines changed

3 files changed

+96
-29
lines changed

sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,15 @@ private[sql] object JsonRDD extends Logging {
6868
val (topLevel, structLike) = values.partition(_.size == 1)
6969
val topLevelFields = topLevel.filter {
7070
name => resolved.get(prefix ++ name).get match {
71-
case ArrayType(StructType(Nil), _) => false
72-
case ArrayType(_, _) => true
71+
case ArrayType(elementType, _) => {
72+
def hasInnerStruct(t: DataType): Boolean = t match {
73+
case s: StructType => false
74+
case ArrayType(t1, _) => hasInnerStruct(t1)
75+
case o => true
76+
}
77+
78+
hasInnerStruct(elementType)
79+
}
7380
case struct: StructType => false
7481
case _ => true
7582
}
@@ -84,7 +91,18 @@ private[sql] object JsonRDD extends Logging {
8491
val dataType = resolved.get(prefix :+ name).get
8592
dataType match {
8693
case array: ArrayType =>
87-
Some(StructField(name, ArrayType(structType, array.containsNull), nullable = true))
94+
// The pattern of this array is ArrayType(...(ArrayType(StructType))).
95+
// Since the inner struct of the array is a placeholder (StructType(Nil)),
96+
// we need to replace this placeholder with the actual StructType (structType).
97+
def getActualArrayType(
98+
innerStruct: StructType,
99+
currentArray: ArrayType): ArrayType = currentArray match {
100+
case ArrayType(s: StructType, containsNull) =>
101+
ArrayType(innerStruct, containsNull)
102+
case ArrayType(a: ArrayType, containsNull) =>
103+
ArrayType(getActualArrayType(innerStruct, a), containsNull)
104+
}
105+
Some(StructField(name, getActualArrayType(structType, array), nullable = true))
88106
case struct: StructType => Some(StructField(name, structType, nullable = true))
89107
// dataType is StringType means that we have resolved type conflicts involving
90108
// primitive types and complex types. So, the type of name has been relaxed to
@@ -168,8 +186,7 @@ private[sql] object JsonRDD extends Logging {
168186
/**
169187
* Returns the element type of a JSON array. We go through all elements of this array
170188
* to detect any possible type conflict. We use [[compatibleType]] to resolve
171-
* type conflicts. Right now, when the element of an array is another array, we
172-
* treat the element as String.
189+
* type conflicts.
173190
*/
174191
private def typeOfArray(l: Seq[Any]): ArrayType = {
175192
val containsNull = l.exists(v => v == null)
@@ -216,18 +233,24 @@ private[sql] object JsonRDD extends Logging {
216233
}
217234
case (key: String, array: Seq[_]) => {
218235
// The value associated with the key is an array.
219-
typeOfArray(array) match {
236+
// Handle inner structs of an array.
237+
def buildKeyPathForInnerStructs(v: Any, t: DataType): Seq[(String, DataType)] = t match {
220238
case ArrayType(StructType(Nil), containsNull) => {
221239
// The elements of this array are structs.
222-
array.asInstanceOf[Seq[Map[String, Any]]].flatMap {
240+
v.asInstanceOf[Seq[Map[String, Any]]].flatMap {
223241
element => allKeysWithValueTypes(element)
224242
}.map {
225-
case (k, dataType) => (s"$key.$k", dataType)
226-
} :+ (key, ArrayType(StructType(Nil), containsNull))
243+
case (k, t) => (s"$key.$k", t)
244+
}
227245
}
228-
case ArrayType(elementType, containsNull) =>
229-
(key, ArrayType(elementType, containsNull)) :: Nil
246+
case ArrayType(t1, containsNull) =>
247+
v.asInstanceOf[Seq[Any]].flatMap {
248+
element => buildKeyPathForInnerStructs(element, t1)
249+
}
250+
case other => Nil
230251
}
252+
val elementType = typeOfArray(array)
253+
buildKeyPathForInnerStructs(array, elementType) :+ (key, elementType)
231254
}
232255
case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil
233256
}
@@ -339,15 +362,17 @@ private[sql] object JsonRDD extends Logging {
339362
null
340363
} else {
341364
desiredType match {
342-
case ArrayType(elementType, _) =>
343-
value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
344365
case StringType => toString(value)
345366
case IntegerType => value.asInstanceOf[IntegerType.JvmType]
346367
case LongType => toLong(value)
347368
case DoubleType => toDouble(value)
348369
case DecimalType => toDecimal(value)
349370
case BooleanType => value.asInstanceOf[BooleanType.JvmType]
350371
case NullType => null
372+
373+
case ArrayType(elementType, _) =>
374+
value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
375+
case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
351376
}
352377
}
353378
}
@@ -356,22 +381,9 @@ private[sql] object JsonRDD extends Logging {
356381
// TODO: Reuse the row instead of creating a new one for every record.
357382
val row = new GenericMutableRow(schema.fields.length)
358383
schema.fields.zipWithIndex.foreach {
359-
// StructType
360-
case (StructField(name, fields: StructType, _), i) =>
361-
row.update(i, json.get(name).flatMap(v => Option(v)).map(
362-
v => asRow(v.asInstanceOf[Map[String, Any]], fields)).orNull)
363-
364-
// ArrayType(StructType)
365-
case (StructField(name, ArrayType(structType: StructType, _), _), i) =>
366-
row.update(i,
367-
json.get(name).flatMap(v => Option(v)).map(
368-
v => v.asInstanceOf[Seq[Any]].map(
369-
e => asRow(e.asInstanceOf[Map[String, Any]], structType))).orNull)
370-
371-
// Other cases
372384
case (StructField(name, dataType, _), i) =>
373385
row.update(i, json.get(name).flatMap(v => Option(v)).map(
374-
enforceCorrectType(_, dataType)).getOrElse(null))
386+
enforceCorrectType(_, dataType)).orNull)
375387
}
376388

377389
row

sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,8 +591,35 @@ class JsonSuite extends QueryTest {
591591
(true, "str1") :: Nil
592592
)
593593
checkAnswer(
594-
sql("select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] from jsonTable"),
594+
sql(
595+
"""
596+
|select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1]
597+
|from jsonTable
598+
""".stripMargin),
595599
("str2", 6) :: Nil
596600
)
597601
}
602+
603+
test("SPARK-3390 Complex arrays") {
604+
val jsonSchemaRDD = jsonRDD(complexFieldAndType2)
605+
jsonSchemaRDD.registerTempTable("jsonTable")
606+
607+
checkAnswer(
608+
sql(
609+
"""
610+
|select arrayOfArray1[0][0][0], arrayOfArray1[1][0][1], arrayOfArray1[1][1][0]
611+
|from jsonTable
612+
""".stripMargin),
613+
(5, 7, 8) :: Nil
614+
)
615+
checkAnswer(
616+
sql(
617+
"""
618+
|select arrayOfArray2[0][0][0].inner1, arrayOfArray2[1][0],
619+
|arrayOfArray2[1][1][1].inner2[0], arrayOfArray2[2][0][0].inner3[0][0].inner4
620+
|from jsonTable
621+
""".stripMargin),
622+
("str1", Nil, "str4", 2) :: Nil
623+
)
624+
}
598625
}

sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,34 @@ object TestJsonData {
106106
"inner1": "str4"
107107
}],
108108
"field2": [[5, 6], [7, 8]]
109-
}]
109+
}],
110+
"arrayOfArray1": [
111+
[
112+
[5]
113+
],
114+
[
115+
[6, 7],
116+
[8]
117+
]],
118+
"arrayOfArray2": [
119+
[
120+
[
121+
{
122+
"inner1": "str1"
123+
}
124+
]
125+
],
126+
[
127+
[],
128+
[
129+
{"inner2": ["str3", "str33"]},
130+
{"inner2": ["str4"], "inner1": "str11"}
131+
]
132+
],
133+
[
134+
[
135+
{"inner3": [[{"inner4": 2}]]}
136+
]
137+
]]
110138
}""" :: Nil)
111139
}

0 commit comments

Comments
 (0)