Skip to content

Commit 35a5853

Browse files
ueshin authored and marmbrus committed
[SPARK-2969][SQL] Make ScalaReflection be able to handle ArrayType.containsNull and MapType.valueContainsNull.
Make `ScalaReflection` able to handle the following mappings:

- `Seq[Int]` as `ArrayType(IntegerType, containsNull = false)`
- `Seq[java.lang.Integer]` as `ArrayType(IntegerType, containsNull = true)`
- `Map[Int, Long]` as `MapType(IntegerType, LongType, valueContainsNull = false)`
- `Map[Int, java.lang.Long]` as `MapType(IntegerType, LongType, valueContainsNull = true)`

Author: Takuya UESHIN <[email protected]>

Closes #1889 from ueshin/issues/SPARK-2969 and squashes the following commits:

- 24f1c5c [Takuya UESHIN] Change the default value of ArrayType.containsNull to true in Python API.
- 79f5b65 [Takuya UESHIN] Change the default value of ArrayType.containsNull to true in Java API.
- 7cd1a7a [Takuya UESHIN] Fix json test failures.
- 2cfb862 [Takuya UESHIN] Change the default value of ArrayType.containsNull to true.
- 2f38e61 [Takuya UESHIN] Revert the default value of MapTypes.valueContainsNull.
- 9fa02f5 [Takuya UESHIN] Fix a test failure.
- 1a9a96b [Takuya UESHIN] Modify ScalaReflection to handle ArrayType.containsNull and MapType.valueContainsNull.

(cherry picked from commit 98c2bb0)

Signed-off-by: Michael Armbrust <[email protected]>
1 parent 83d2730 commit 35a5853

File tree

7 files changed

+49
-30
lines changed

7 files changed

+49
-30
lines changed

python/pyspark/sql.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,15 @@ class ArrayType(DataType):
186186
187187
"""
188188

189-
def __init__(self, elementType, containsNull=False):
189+
def __init__(self, elementType, containsNull=True):
190190
"""Creates an ArrayType
191191
192192
:param elementType: the data type of elements.
193193
:param containsNull: indicates whether the list contains None values.
194194
195-
>>> ArrayType(StringType) == ArrayType(StringType, False)
195+
>>> ArrayType(StringType) == ArrayType(StringType, True)
196196
True
197-
>>> ArrayType(StringType, True) == ArrayType(StringType)
197+
>>> ArrayType(StringType, False) == ArrayType(StringType)
198198
False
199199
"""
200200
self.elementType = elementType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,14 @@ object ScalaReflection {
6262
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
6363
case t if t <:< typeOf[Seq[_]] =>
6464
val TypeRef(_, _, Seq(elementType)) = t
65-
Schema(ArrayType(schemaFor(elementType).dataType), nullable = true)
65+
val Schema(dataType, nullable) = schemaFor(elementType)
66+
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
6667
case t if t <:< typeOf[Map[_,_]] =>
6768
val TypeRef(_, _, Seq(keyType, valueType)) = t
68-
Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true)
69-
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
69+
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
70+
Schema(MapType(schemaFor(keyType).dataType,
71+
valueDataType, valueContainsNull = valueNullable), nullable = true)
72+
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
7073
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
7174
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
7275
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,8 @@ case object FloatType extends FractionalType {
270270
}
271271

272272
object ArrayType {
273-
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is false. */
274-
def apply(elementType: DataType): ArrayType = ArrayType(elementType, false)
273+
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
274+
def apply(elementType: DataType): ArrayType = ArrayType(elementType, true)
275275
}
276276

277277
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ case class OptionalData(
5757

5858
case class ComplexData(
5959
arrayField: Seq[Int],
60-
mapField: Map[Int, String],
60+
arrayFieldContainsNull: Seq[java.lang.Integer],
61+
mapField: Map[Int, Long],
62+
mapFieldValueContainsNull: Map[Int, java.lang.Long],
6163
structField: PrimitiveData)
6264

6365
case class GenericData[A](
@@ -116,8 +118,22 @@ class ScalaReflectionSuite extends FunSuite {
116118
val schema = schemaFor[ComplexData]
117119
assert(schema === Schema(
118120
StructType(Seq(
119-
StructField("arrayField", ArrayType(IntegerType), nullable = true),
120-
StructField("mapField", MapType(IntegerType, StringType), nullable = true),
121+
StructField(
122+
"arrayField",
123+
ArrayType(IntegerType, containsNull = false),
124+
nullable = true),
125+
StructField(
126+
"arrayFieldContainsNull",
127+
ArrayType(IntegerType, containsNull = true),
128+
nullable = true),
129+
StructField(
130+
"mapField",
131+
MapType(IntegerType, LongType, valueContainsNull = false),
132+
nullable = true),
133+
StructField(
134+
"mapFieldValueContainsNull",
135+
MapType(IntegerType, LongType, valueContainsNull = true),
136+
nullable = true),
121137
StructField(
122138
"structField",
123139
StructType(Seq(

sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,14 @@ public abstract class DataType {
8686

8787
/**
8888
* Creates an ArrayType by specifying the data type of elements ({@code elementType}).
89-
* The field of {@code containsNull} is set to {@code false}.
89+
* The field of {@code containsNull} is set to {@code true}.
9090
*/
9191
public static ArrayType createArrayType(DataType elementType) {
9292
if (elementType == null) {
9393
throw new IllegalArgumentException("elementType should not be null.");
9494
}
9595

96-
return new ArrayType(elementType, false);
96+
return new ArrayType(elementType, true);
9797
}
9898

9999
/**

sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class DataTypeSuite extends FunSuite {
2424
test("construct an ArrayType") {
2525
val array = ArrayType(StringType)
2626

27-
assert(ArrayType(StringType, false) === array)
27+
assert(ArrayType(StringType, true) === array)
2828
}
2929

3030
test("construct an MapType") {

sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,11 @@ class JsonSuite extends QueryTest {
130130
checkDataType(
131131
ArrayType(IntegerType, true), ArrayType(IntegerType, true), ArrayType(IntegerType, true))
132132
checkDataType(
133-
ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, false))
133+
ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, true))
134134
checkDataType(
135135
ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType, false))
136136
checkDataType(
137-
ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType))
137+
ArrayType(IntegerType, false), ArrayType(IntegerType, true), ArrayType(IntegerType, true))
138138

139139
// StructType
140140
checkDataType(StructType(Nil), StructType(Nil), StructType(Nil))
@@ -201,26 +201,26 @@ class JsonSuite extends QueryTest {
201201
val jsonSchemaRDD = jsonRDD(complexFieldAndType)
202202

203203
val expectedSchema = StructType(
204-
StructField("arrayOfArray1", ArrayType(ArrayType(StringType)), true) ::
205-
StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true) ::
206-
StructField("arrayOfBigInteger", ArrayType(DecimalType), true) ::
207-
StructField("arrayOfBoolean", ArrayType(BooleanType), true) ::
208-
StructField("arrayOfDouble", ArrayType(DoubleType), true) ::
209-
StructField("arrayOfInteger", ArrayType(IntegerType), true) ::
210-
StructField("arrayOfLong", ArrayType(LongType), true) ::
204+
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) ::
205+
StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) ::
206+
StructField("arrayOfBigInteger", ArrayType(DecimalType, false), true) ::
207+
StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) ::
208+
StructField("arrayOfDouble", ArrayType(DoubleType, false), true) ::
209+
StructField("arrayOfInteger", ArrayType(IntegerType, false), true) ::
210+
StructField("arrayOfLong", ArrayType(LongType, false), true) ::
211211
StructField("arrayOfNull", ArrayType(StringType, true), true) ::
212-
StructField("arrayOfString", ArrayType(StringType), true) ::
212+
StructField("arrayOfString", ArrayType(StringType, false), true) ::
213213
StructField("arrayOfStruct", ArrayType(
214214
StructType(
215215
StructField("field1", BooleanType, true) ::
216216
StructField("field2", StringType, true) ::
217-
StructField("field3", StringType, true) :: Nil)), true) ::
217+
StructField("field3", StringType, true) :: Nil), false), true) ::
218218
StructField("struct", StructType(
219219
StructField("field1", BooleanType, true) ::
220220
StructField("field2", DecimalType, true) :: Nil), true) ::
221221
StructField("structWithArrayFields", StructType(
222-
StructField("field1", ArrayType(IntegerType), true) ::
223-
StructField("field2", ArrayType(StringType), true) :: Nil), true) :: Nil)
222+
StructField("field1", ArrayType(IntegerType, false), true) ::
223+
StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil)
224224

225225
assert(expectedSchema === jsonSchemaRDD.schema)
226226

@@ -441,7 +441,7 @@ class JsonSuite extends QueryTest {
441441
val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict)
442442

443443
val expectedSchema = StructType(
444-
StructField("array", ArrayType(IntegerType), true) ::
444+
StructField("array", ArrayType(IntegerType, false), true) ::
445445
StructField("num_struct", StringType, true) ::
446446
StructField("str_array", StringType, true) ::
447447
StructField("struct", StructType(
@@ -467,7 +467,7 @@ class JsonSuite extends QueryTest {
467467
val expectedSchema = StructType(
468468
StructField("array1", ArrayType(StringType, true), true) ::
469469
StructField("array2", ArrayType(StructType(
470-
StructField("field", LongType, true) :: Nil)), true) :: Nil)
470+
StructField("field", LongType, true) :: Nil), false), true) :: Nil)
471471

472472
assert(expectedSchema === jsonSchemaRDD.schema)
473473

@@ -492,7 +492,7 @@ class JsonSuite extends QueryTest {
492492
val expectedSchema = StructType(
493493
StructField("a", BooleanType, true) ::
494494
StructField("b", LongType, true) ::
495-
StructField("c", ArrayType(IntegerType), true) ::
495+
StructField("c", ArrayType(IntegerType, false), true) ::
496496
StructField("d", StructType(
497497
StructField("field", BooleanType, true) :: Nil), true) ::
498498
StructField("e", StringType, true) :: Nil)

0 commit comments

Comments
 (0)