Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,15 @@ class ArrayType(DataType):

"""

def __init__(self, elementType, containsNull=False):
def __init__(self, elementType, containsNull=True):
"""Creates an ArrayType

:param elementType: the data type of elements.
:param containsNull: indicates whether the list contains None values.

>>> ArrayType(StringType) == ArrayType(StringType, False)
>>> ArrayType(StringType) == ArrayType(StringType, True)
True
>>> ArrayType(StringType, True) == ArrayType(StringType)
>>> ArrayType(StringType, False) == ArrayType(StringType)
False
"""
self.elementType = elementType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,14 @@ object ScalaReflection {
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
Schema(ArrayType(schemaFor(elementType).dataType), nullable = true)
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_,_]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ case object FloatType extends FractionalType {
}

object ArrayType {
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is false. */
def apply(elementType: DataType): ArrayType = ArrayType(elementType, false)
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
def apply(elementType: DataType): ArrayType = ArrayType(elementType, true)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ case class OptionalData(

case class ComplexData(
arrayField: Seq[Int],
mapField: Map[Int, String],
arrayFieldContainsNull: Seq[java.lang.Integer],
mapField: Map[Int, Long],
mapFieldValueContainsNull: Map[Int, java.lang.Long],
structField: PrimitiveData)

case class GenericData[A](
Expand Down Expand Up @@ -116,8 +118,22 @@ class ScalaReflectionSuite extends FunSuite {
val schema = schemaFor[ComplexData]
assert(schema === Schema(
StructType(Seq(
StructField("arrayField", ArrayType(IntegerType), nullable = true),
StructField("mapField", MapType(IntegerType, StringType), nullable = true),
StructField(
"arrayField",
ArrayType(IntegerType, containsNull = false),
nullable = true),
StructField(
"arrayFieldContainsNull",
ArrayType(IntegerType, containsNull = true),
nullable = true),
StructField(
"mapField",
MapType(IntegerType, LongType, valueContainsNull = false),
nullable = true),
StructField(
"mapFieldValueContainsNull",
MapType(IntegerType, LongType, valueContainsNull = true),
nullable = true),
StructField(
"structField",
StructType(Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ public abstract class DataType {

/**
* Creates an ArrayType by specifying the data type of elements ({@code elementType}).
* The field of {@code containsNull} is set to {@code false}.
* The field of {@code containsNull} is set to {@code true}.
*/
public static ArrayType createArrayType(DataType elementType) {
if (elementType == null) {
throw new IllegalArgumentException("elementType should not be null.");
}

return new ArrayType(elementType, false);
return new ArrayType(elementType, true);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class DataTypeSuite extends FunSuite {
test("construct an ArrayType") {
val array = ArrayType(StringType)

assert(ArrayType(StringType, false) === array)
assert(ArrayType(StringType, true) === array)
}

test("construct an MapType") {
Expand Down
32 changes: 16 additions & 16 deletions sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ class JsonSuite extends QueryTest {
checkDataType(
ArrayType(IntegerType, true), ArrayType(IntegerType, true), ArrayType(IntegerType, true))
checkDataType(
ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, false))
ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, true))
checkDataType(
ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType, false))
checkDataType(
ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType))
ArrayType(IntegerType, false), ArrayType(IntegerType, true), ArrayType(IntegerType, true))

// StructType
checkDataType(StructType(Nil), StructType(Nil), StructType(Nil))
Expand Down Expand Up @@ -201,26 +201,26 @@ class JsonSuite extends QueryTest {
val jsonSchemaRDD = jsonRDD(complexFieldAndType)

val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType)), true) ::
StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true) ::
StructField("arrayOfBigInteger", ArrayType(DecimalType), true) ::
StructField("arrayOfBoolean", ArrayType(BooleanType), true) ::
StructField("arrayOfDouble", ArrayType(DoubleType), true) ::
StructField("arrayOfInteger", ArrayType(IntegerType), true) ::
StructField("arrayOfLong", ArrayType(LongType), true) ::
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) ::
StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) ::
StructField("arrayOfBigInteger", ArrayType(DecimalType, false), true) ::
StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) ::
StructField("arrayOfDouble", ArrayType(DoubleType, false), true) ::
StructField("arrayOfInteger", ArrayType(IntegerType, false), true) ::
StructField("arrayOfLong", ArrayType(LongType, false), true) ::
StructField("arrayOfNull", ArrayType(StringType, true), true) ::
StructField("arrayOfString", ArrayType(StringType), true) ::
StructField("arrayOfString", ArrayType(StringType, false), true) ::
StructField("arrayOfStruct", ArrayType(
StructType(
StructField("field1", BooleanType, true) ::
StructField("field2", StringType, true) ::
StructField("field3", StringType, true) :: Nil)), true) ::
StructField("field3", StringType, true) :: Nil), false), true) ::
StructField("struct", StructType(
StructField("field1", BooleanType, true) ::
StructField("field2", DecimalType, true) :: Nil), true) ::
StructField("structWithArrayFields", StructType(
StructField("field1", ArrayType(IntegerType), true) ::
StructField("field2", ArrayType(StringType), true) :: Nil), true) :: Nil)
StructField("field1", ArrayType(IntegerType, false), true) ::
StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil)

assert(expectedSchema === jsonSchemaRDD.schema)

Expand Down Expand Up @@ -441,7 +441,7 @@ class JsonSuite extends QueryTest {
val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict)

val expectedSchema = StructType(
StructField("array", ArrayType(IntegerType), true) ::
StructField("array", ArrayType(IntegerType, false), true) ::
StructField("num_struct", StringType, true) ::
StructField("str_array", StringType, true) ::
StructField("struct", StructType(
Expand All @@ -467,7 +467,7 @@ class JsonSuite extends QueryTest {
val expectedSchema = StructType(
StructField("array1", ArrayType(StringType, true), true) ::
StructField("array2", ArrayType(StructType(
StructField("field", LongType, true) :: Nil)), true) :: Nil)
StructField("field", LongType, true) :: Nil), false), true) :: Nil)

assert(expectedSchema === jsonSchemaRDD.schema)

Expand All @@ -492,7 +492,7 @@ class JsonSuite extends QueryTest {
val expectedSchema = StructType(
StructField("a", BooleanType, true) ::
StructField("b", LongType, true) ::
StructField("c", ArrayType(IntegerType), true) ::
StructField("c", ArrayType(IntegerType, false), true) ::
StructField("d", StructType(
StructField("field", BooleanType, true) :: Nil), true) ::
StructField("e", StringType, true) :: Nil)
Expand Down