Skip to content

Commit e4899a2

Browse files
ueshinrxin
authored andcommitted
[SPARK-2254] [SQL] ScalaRefection should mark primitive types as non-nullable.
Author: Takuya UESHIN <[email protected]> Closes #1193 from ueshin/issues/SPARK-2254 and squashes the following commits: cfd6088 [Takuya UESHIN] Modify ScalaRefection.schemaFor method to return nullability of Scala Type.
1 parent 441cdcc commit e4899a2

File tree

2 files changed

+165
-31
lines changed

2 files changed

+165
-31
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,53 +30,56 @@ import org.apache.spark.sql.catalyst.types._
3030
object ScalaReflection {
3131
import scala.reflect.runtime.universe._
3232

33+
case class Schema(dataType: DataType, nullable: Boolean)
34+
3335
/** Returns a Sequence of attributes for the given case class type. */
3436
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
35-
case s: StructType =>
36-
s.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)())
37+
case Schema(s: StructType, _) =>
38+
s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
3739
}
3840

39-
/** Returns a catalyst DataType for the given Scala Type using reflection. */
40-
def schemaFor[T: TypeTag]: DataType = schemaFor(typeOf[T])
41+
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
42+
def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])
4143

42-
/** Returns a catalyst DataType for the given Scala Type using reflection. */
43-
def schemaFor(tpe: `Type`): DataType = tpe match {
44+
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
45+
def schemaFor(tpe: `Type`): Schema = tpe match {
4446
case t if t <:< typeOf[Option[_]] =>
4547
val TypeRef(_, _, Seq(optType)) = t
46-
schemaFor(optType)
48+
Schema(schemaFor(optType).dataType, nullable = true)
4749
case t if t <:< typeOf[Product] =>
4850
val params = t.member("<init>": TermName).asMethod.paramss
49-
StructType(
50-
params.head.map(p =>
51-
StructField(p.name.toString, schemaFor(p.typeSignature), nullable = true)))
51+
Schema(StructType(
52+
params.head.map { p =>
53+
val Schema(dataType, nullable) = schemaFor(p.typeSignature)
54+
StructField(p.name.toString, dataType, nullable)
55+
}), nullable = true)
5256
// Need to decide if we actually need a special type here.
53-
case t if t <:< typeOf[Array[Byte]] => BinaryType
57+
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
5458
case t if t <:< typeOf[Array[_]] =>
5559
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
5660
case t if t <:< typeOf[Seq[_]] =>
5761
val TypeRef(_, _, Seq(elementType)) = t
58-
ArrayType(schemaFor(elementType))
62+
Schema(ArrayType(schemaFor(elementType).dataType), nullable = true)
5963
case t if t <:< typeOf[Map[_,_]] =>
6064
val TypeRef(_, _, Seq(keyType, valueType)) = t
61-
MapType(schemaFor(keyType), schemaFor(valueType))
62-
case t if t <:< typeOf[String] => StringType
63-
case t if t <:< typeOf[Timestamp] => TimestampType
64-
case t if t <:< typeOf[BigDecimal] => DecimalType
65-
case t if t <:< typeOf[java.lang.Integer] => IntegerType
66-
case t if t <:< typeOf[java.lang.Long] => LongType
67-
case t if t <:< typeOf[java.lang.Double] => DoubleType
68-
case t if t <:< typeOf[java.lang.Float] => FloatType
69-
case t if t <:< typeOf[java.lang.Short] => ShortType
70-
case t if t <:< typeOf[java.lang.Byte] => ByteType
71-
case t if t <:< typeOf[java.lang.Boolean] => BooleanType
72-
// TODO: The following datatypes could be marked as non-nullable.
73-
case t if t <:< definitions.IntTpe => IntegerType
74-
case t if t <:< definitions.LongTpe => LongType
75-
case t if t <:< definitions.DoubleTpe => DoubleType
76-
case t if t <:< definitions.FloatTpe => FloatType
77-
case t if t <:< definitions.ShortTpe => ShortType
78-
case t if t <:< definitions.ByteTpe => ByteType
79-
case t if t <:< definitions.BooleanTpe => BooleanType
65+
Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true)
66+
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
67+
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
68+
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
69+
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
70+
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
71+
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
72+
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
73+
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
74+
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
75+
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
76+
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
77+
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
78+
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
79+
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
80+
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
81+
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
82+
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
8083
}
8184

8285
implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst
19+
20+
import java.sql.Timestamp
21+
22+
import org.scalatest.FunSuite
23+
24+
import org.apache.spark.sql.catalyst.expressions._
25+
import org.apache.spark.sql.catalyst.types._
26+
27+
case class PrimitiveData(
28+
intField: Int,
29+
longField: Long,
30+
doubleField: Double,
31+
floatField: Float,
32+
shortField: Short,
33+
byteField: Byte,
34+
booleanField: Boolean)
35+
36+
case class NullableData(
37+
intField: java.lang.Integer,
38+
longField: java.lang.Long,
39+
doubleField: java.lang.Double,
40+
floatField: java.lang.Float,
41+
shortField: java.lang.Short,
42+
byteField: java.lang.Byte,
43+
booleanField: java.lang.Boolean,
44+
stringField: String,
45+
decimalField: BigDecimal,
46+
timestampField: Timestamp,
47+
binaryField: Array[Byte])
48+
49+
case class OptionalData(
50+
intField: Option[Int],
51+
longField: Option[Long],
52+
doubleField: Option[Double],
53+
floatField: Option[Float],
54+
shortField: Option[Short],
55+
byteField: Option[Byte],
56+
booleanField: Option[Boolean])
57+
58+
case class ComplexData(
59+
arrayField: Seq[Int],
60+
mapField: Map[Int, String],
61+
structField: PrimitiveData)
62+
63+
class ScalaReflectionSuite extends FunSuite {
64+
import ScalaReflection._
65+
66+
test("primitive data") {
67+
val schema = schemaFor[PrimitiveData]
68+
assert(schema === Schema(
69+
StructType(Seq(
70+
StructField("intField", IntegerType, nullable = false),
71+
StructField("longField", LongType, nullable = false),
72+
StructField("doubleField", DoubleType, nullable = false),
73+
StructField("floatField", FloatType, nullable = false),
74+
StructField("shortField", ShortType, nullable = false),
75+
StructField("byteField", ByteType, nullable = false),
76+
StructField("booleanField", BooleanType, nullable = false))),
77+
nullable = true))
78+
}
79+
80+
test("nullable data") {
81+
val schema = schemaFor[NullableData]
82+
assert(schema === Schema(
83+
StructType(Seq(
84+
StructField("intField", IntegerType, nullable = true),
85+
StructField("longField", LongType, nullable = true),
86+
StructField("doubleField", DoubleType, nullable = true),
87+
StructField("floatField", FloatType, nullable = true),
88+
StructField("shortField", ShortType, nullable = true),
89+
StructField("byteField", ByteType, nullable = true),
90+
StructField("booleanField", BooleanType, nullable = true),
91+
StructField("stringField", StringType, nullable = true),
92+
StructField("decimalField", DecimalType, nullable = true),
93+
StructField("timestampField", TimestampType, nullable = true),
94+
StructField("binaryField", BinaryType, nullable = true))),
95+
nullable = true))
96+
}
97+
98+
test("optinal data") {
99+
val schema = schemaFor[OptionalData]
100+
assert(schema === Schema(
101+
StructType(Seq(
102+
StructField("intField", IntegerType, nullable = true),
103+
StructField("longField", LongType, nullable = true),
104+
StructField("doubleField", DoubleType, nullable = true),
105+
StructField("floatField", FloatType, nullable = true),
106+
StructField("shortField", ShortType, nullable = true),
107+
StructField("byteField", ByteType, nullable = true),
108+
StructField("booleanField", BooleanType, nullable = true))),
109+
nullable = true))
110+
}
111+
112+
test("complex data") {
113+
val schema = schemaFor[ComplexData]
114+
assert(schema === Schema(
115+
StructType(Seq(
116+
StructField("arrayField", ArrayType(IntegerType), nullable = true),
117+
StructField("mapField", MapType(IntegerType, StringType), nullable = true),
118+
StructField(
119+
"structField",
120+
StructType(Seq(
121+
StructField("intField", IntegerType, nullable = false),
122+
StructField("longField", LongType, nullable = false),
123+
StructField("doubleField", DoubleType, nullable = false),
124+
StructField("floatField", FloatType, nullable = false),
125+
StructField("shortField", ShortType, nullable = false),
126+
StructField("byteField", ByteType, nullable = false),
127+
StructField("booleanField", BooleanType, nullable = false))),
128+
nullable = true))),
129+
nullable = true))
130+
}
131+
}

0 commit comments

Comments
 (0)