
Commit 013c2ca

Add MapType containing null value support to Parquet.
1 parent 62989de commit 013c2ca

3 files changed: +29 −28 lines changed

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala

Lines changed: 10 additions & 12 deletions
@@ -247,26 +247,24 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
     writer.endGroup()
   }
 
-  // TODO: support null values, see
-  // https://issues.apache.org/jira/browse/SPARK-1649
   private[parquet] def writeMap(
       schema: MapType,
       map: CatalystConverter.MapScalaType[_, _]): Unit = {
     writer.startGroup()
     if (map.size > 0) {
       writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0)
-      writer.startGroup()
-      writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0)
-      for(key <- map.keys) {
+      for ((key, value) <- map) {
+        writer.startGroup()
+        writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0)
         writeValue(schema.keyType, key)
+        writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0)
+        if (value != null) {
+          writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1)
+          writeValue(schema.valueType, value)
+          writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1)
+        }
+        writer.endGroup()
       }
-      writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0)
-      writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1)
-      for(value <- map.values) {
-        writeValue(schema.valueType, value)
-      }
-      writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1)
-      writer.endGroup()
       writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0)
     }
     writer.endGroup()
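Why the restructuring: the old writer emitted every key into one repeated key field and then every value into a separate repeated value field, so a single absent value had no representation. The rewrite emits one repeated group per map entry and simply leaves the optional value field unset when the value is null. A minimal, dependency-free sketch of that per-entry layout (FieldSink and all names below are hypothetical stand-ins for Parquet's RecordConsumer-style API, for illustration only):

  trait FieldSink {
    def startGroup(): Unit
    def endGroup(): Unit
    def startField(name: String, index: Int): Unit
    def endField(name: String, index: Int): Unit
    def write(v: Any): Unit
  }

  def writeEntries(map: Map[Any, Any], sink: FieldSink): Unit = {
    for ((key, value) <- map) {
      sink.startGroup()               // one group per (key, value) entry
      sink.startField("key", 0)
      sink.write(key)
      sink.endField("key", 0)
      if (value != null) {            // a null value leaves the optional
        sink.startField("value", 1)   // "value" field absent in this group,
        sink.write(value)             // which reads back as null
        sink.endField("value", 1)
      }
      sink.endGroup()
    }
  }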

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala

Lines changed: 10 additions & 12 deletions
@@ -129,25 +129,23 @@ private[parquet] object ParquetTypesConverter extends Logging {
         assert(
           keyValueGroup.getFieldCount == 2,
           "Parquet Map type malformatted: nested group should have 2 (key, value) fields!")
-        val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString)
         assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
+
+        val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString)
         val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString)
-        assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
-        // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true
-        // at here.
-        MapType(keyType, valueType)
+        MapType(keyType, valueType,
+          keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED)
       }
       case _ => {
         // Note: the order of these checks is important!
         if (correspondsToMap(groupType)) { // MapType
           val keyValueGroup = groupType.getFields.apply(0).asGroupType()
-          val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString)
           assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
+
+          val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString)
           val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString)
-          assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
-          // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true
-          // at here.
-          MapType(keyType, valueType)
+          MapType(keyType, valueType,
+            keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED)
         } else if (correspondsToArray(groupType)) { // ArrayType
           val elementType = toDataType(groupType.getFields.apply(0), isBinaryAsString)
           ArrayType(elementType, containsNull = false)
@@ -255,7 +253,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
       }
       new ParquetGroupType(repetition, name, fields)
     }
-    case MapType(keyType, valueType, _) => {
+    case MapType(keyType, valueType, valueContainsNull) => {
       val parquetKeyType =
         fromDataType(
           keyType,
@@ -266,7 +264,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
         fromDataType(
           valueType,
           CatalystConverter.MAP_VALUE_SCHEMA_NAME,
-          nullable = false,
+          nullable = valueContainsNull,
           inArray = false)
       ConversionPatterns.mapType(
         repetition,
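Both directions of the conversion now derive nullability from the Parquet repetition of the map's value field instead of hard-coding it. A sketch of the two-way mapping these hunks establish (helper names are hypothetical; Repetition is the pre-Apache parquet.schema.Type.Repetition used at the time):

  import parquet.schema.Type.Repetition

  // Reading: a non-REQUIRED (i.e. OPTIONAL) value field means the
  // resulting Catalyst MapType may contain null values.
  def valueContainsNullOf(valueRepetition: Repetition): Boolean =
    valueRepetition != Repetition.REQUIRED

  // Writing: valueContainsNull = true is what nullable = valueContainsNull
  // feeds into fromDataType above, producing an OPTIONAL value field.
  def valueRepetitionOf(valueContainsNull: Boolean): Repetition =
    if (valueContainsNull) Repetition.OPTIONAL else Repetition.REQUIRED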

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala

Lines changed: 9 additions & 4 deletions
@@ -78,7 +78,8 @@ case class AllDataTypesWithNonPrimitiveType(
     booleanField: Boolean,
     binaryField: Array[Byte],
     array: Seq[Int],
-    map: Map[Int, String],
+    map: Map[Int, Long],
+    mapValueContainsNull: Map[Int, Option[Long]],
     data: Data)
 
 class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
@@ -193,7 +194,10 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
       .map(x => AllDataTypesWithNonPrimitiveType(
         s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
         (0 to x).map(_.toByte).toArray,
-        (0 until x), (0 until x).map(i => i -> s"$i").toMap, Data((0 until x), Nested(x, s"$x"))))
+        (0 until x),
+        (0 until x).map(i => i -> i.toLong).toMap,
+        (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None),
+        Data((0 until x), Nested(x, s"$x"))))
       .saveAsParquetFile(tempDir)
     val result = parquetFile(tempDir).collect()
     range.foreach {
@@ -208,8 +212,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
       assert(result(i).getBoolean(7) === (i % 2 == 0))
       assert(result(i)(8) === (0 to i).map(_.toByte).toArray)
       assert(result(i)(9) === (0 until i))
-      assert(result(i)(10) === (0 until i).map(i => i -> s"$i").toMap)
-      assert(result(i)(11) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i")))))
+      assert(result(i)(10) === (0 until i).map(i => i -> i.toLong).toMap)
+      assert(result(i)(11) === (0 until i).map(i => i -> i.toLong).toMap + (i -> null))
+      assert(result(i)(12) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i")))))
     }
   }
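For reference, the end-to-end behavior the new mapValueContainsNull column exercises looks roughly like this from user code (a sketch against the Spark 1.x SchemaRDD API; the output path and the sc/sqlContext variables are assumptions):

  import sqlContext.createSchemaRDD  // implicit RDD[case class] -> SchemaRDD

  case class Record(id: Int, m: Map[Int, Option[Long]])

  val rdd = sc.parallelize(Seq(Record(1, Map(1 -> Some(1L), 2 -> None))))
  rdd.saveAsParquetFile("/tmp/maps.parquet")

  // The None entry now round-trips as a null map value; before this commit
  // null map values were unsupported on the Parquet path.
  val row = sqlContext.parquetFile("/tmp/maps.parquet").collect().head
  // row(1) == Map(1 -> 1, 2 -> null)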
