Skip to content

Commit 358300c

Browse files
adrian-wang and marmbrus
authored and committed
[SPARK-13056][SQL] map column would throw NPE if value is null
Jira: https://issues.apache.org/jira/browse/SPARK-13056 — To reproduce, create a map like { "a": "somestring", "b": null } and run a query like SELECT col["b"] FROM t1; an NPE would be thrown because the null map value was read without a null check. Author: Daoyuan Wang <[email protected]>. Closes #10964 from adrian-wang/npewriter.
1 parent cba1d6b commit 358300c

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
218218
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
219219
val baseValue = value.asInstanceOf[ArrayData]
220220
val index = ordinal.asInstanceOf[Number].intValue()
221-
if (index >= baseValue.numElements() || index < 0) {
221+
if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) {
222222
null
223223
} else {
224224
baseValue.get(index, dataType)
@@ -267,6 +267,7 @@ case class GetMapValue(child: Expression, key: Expression)
267267
val map = value.asInstanceOf[MapData]
268268
val length = map.numElements()
269269
val keys = map.keyArray()
270+
val values = map.valueArray()
270271

271272
var i = 0
272273
var found = false
@@ -278,10 +279,10 @@ case class GetMapValue(child: Expression, key: Expression)
278279
}
279280
}
280281

281-
if (!found) {
282+
if (!found || values.isNullAt(i)) {
282283
null
283284
} else {
284-
map.valueArray().get(i, dataType)
285+
values.get(i, dataType)
285286
}
286287
}
287288

@@ -291,10 +292,12 @@ case class GetMapValue(child: Expression, key: Expression)
291292
val keys = ctx.freshName("keys")
292293
val found = ctx.freshName("found")
293294
val key = ctx.freshName("key")
295+
val values = ctx.freshName("values")
294296
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
295297
s"""
296298
final int $length = $eval1.numElements();
297299
final ArrayData $keys = $eval1.keyArray();
300+
final ArrayData $values = $eval1.valueArray();
298301

299302
int $index = 0;
300303
boolean $found = false;
@@ -307,10 +310,10 @@ case class GetMapValue(child: Expression, key: Expression)
307310
}
308311
}
309312

310-
if ($found) {
311-
${ev.value} = ${ctx.getValue(eval1 + ".valueArray()", dataType, index)};
312-
} else {
313+
if (!$found || $values.isNullAt($index)) {
313314
${ev.isNull} = true;
315+
} else {
316+
${ev.value} = ${ctx.getValue(values, dataType, index)};
314317
}
315318
"""
316319
})

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2055,6 +2055,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
20552055
)
20562056
}
20572057

2058+
test("SPARK-13056: Null in map value causes NPE") {
2059+
val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value")
2060+
withTempTable("maptest") {
2061+
df.registerTempTable("maptest")
2062+
// local optimization would bypass the codegen path, so we keep the filter `key = 1`
2063+
checkAnswer(sql("SELECT value['abc'] FROM maptest where key = 1"), Row("somestring"))
2064+
checkAnswer(sql("SELECT value['cba'] FROM maptest where key = 1"), Row(null))
2065+
}
2066+
}
2067+
20582068
test("hash function") {
20592069
val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
20602070
withTempTable("tbl") {

0 commit comments

Comments
 (0)