Skip to content

Commit 347cab8

Browse files
committed
[SQL] More unit tests for implicit type cast & add simpleString to AbstractDataType.
Author: Reynold Xin <[email protected]> Closes #7221 from rxin/implicit-cast-tests and squashes the following commits: 64b13bd [Reynold Xin] Fixed a bug .. 489b732 [Reynold Xin] [SQL] More unit tests for implicit type cast & add simpleString to AbstractDataType.
1 parent 48f7aed commit 347cab8

File tree

7 files changed

+42
-4
lines changed

7 files changed

+42
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ trait CheckAnalysis {
4040
def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
4141
exprs.flatMap(_.collect {
4242
case e: Generator => true
43-
}).length >= 1
43+
}).nonEmpty
4444
}
4545

4646
def checkAnalysis(plan: LogicalPlan): Unit = {
@@ -85,12 +85,12 @@ trait CheckAnalysis {
8585
case Aggregate(groupingExprs, aggregateExprs, child) =>
8686
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
8787
case _: AggregateExpression => // OK
88-
case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
88+
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
8989
failAnalysis(
9090
s"expression '${e.prettyString}' is neither present in the group by, " +
9191
s"nor is it an aggregate function. " +
9292
"Add to group by or wrap in first() if you don't care which value you get.")
93-
case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
93+
case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
9494
case e if e.references.isEmpty => // OK
9595
case e => e.children.foreach(checkValidAggregateExpression)
9696
}

sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ private[sql] abstract class AbstractDataType {
3737
* Returns true if this data type is a parent of the `childCandidate`.
3838
*/
3939
private[sql] def isParentOf(childCandidate: DataType): Boolean
40+
41+
/** Readable string representation for the type. */
42+
private[sql] def simpleString: String
4043
}
4144

4245

@@ -56,6 +59,10 @@ private[sql] class TypeCollection(private val types: Seq[DataType]) extends Abst
5659
private[sql] override def defaultConcreteType: DataType = types.head
5760

5861
private[sql] override def isParentOf(childCandidate: DataType): Boolean = false
62+
63+
private[sql] override def simpleString: String = {
64+
types.map(_.simpleString).mkString("(", " or ", ")")
65+
}
5966
}
6067

6168

sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ object ArrayType extends AbstractDataType {
3131
private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
3232
childCandidate.isInstanceOf[ArrayType]
3333
}
34+
35+
private[sql] override def simpleString: String = "array"
3436
}
3537

3638

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ object DecimalType extends AbstractDataType {
9090
childCandidate.isInstanceOf[DecimalType]
9191
}
9292

93+
private[sql] override def simpleString: String = "decimal"
94+
9395
val Unlimited: DecimalType = DecimalType(None)
9496

9597
private[sql] object Fixed {

sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ object MapType extends AbstractDataType {
7575
childCandidate.isInstanceOf[MapType]
7676
}
7777

78+
private[sql] override def simpleString: String = "map"
79+
7880
/**
7981
* Construct a [[MapType]] object with the given key type and value type.
8082
* The `valueContainsNull` is true.

sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,8 @@ object StructType extends AbstractDataType {
309309
childCandidate.isInstanceOf[StructType]
310310
}
311311

312+
private[sql] override def simpleString: String = "struct"
313+
312314
def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
313315

314316
def apply(fields: java.util.List[StructField]): StructType = {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.sql.types._
2626

2727
class HiveTypeCoercionSuite extends PlanTest {
2828

29-
test("implicit type cast") {
29+
test("eligible implicit type cast") {
3030
def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
3131
val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
3232
assert(got.map(_.dataType) == Option(expected),
@@ -68,6 +68,29 @@ class HiveTypeCoercionSuite extends PlanTest {
6868
shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType)
6969
shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
7070
shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)
71+
72+
shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType)
73+
shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType)
74+
}
75+
76+
test("ineligible implicit type cast") {
77+
def shouldNotCast(from: DataType, to: AbstractDataType): Unit = {
78+
val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
79+
assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got")
80+
}
81+
82+
shouldNotCast(IntegerType, DateType)
83+
shouldNotCast(IntegerType, TimestampType)
84+
shouldNotCast(LongType, DateType)
85+
shouldNotCast(LongType, TimestampType)
86+
shouldNotCast(DecimalType.Unlimited, DateType)
87+
shouldNotCast(DecimalType.Unlimited, TimestampType)
88+
89+
shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType))
90+
91+
shouldNotCast(IntegerType, ArrayType)
92+
shouldNotCast(IntegerType, MapType)
93+
shouldNotCast(IntegerType, StructType)
7194
}
7295

7396
test("tightest common bound for types") {

0 commit comments

Comments
 (0)