From 489b7328863fd96760ab5433be8f812d3c1281e5 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Fri, 3 Jul 2015 23:53:57 -0700
Subject: [PATCH 1/2] [SQL] More unit tests for implicit type cast & add simpleString to AbstractDataType.

---
 .../sql/catalyst/analysis/CheckAnalysis.scala |  6 ++---
 .../spark/sql/types/AbstractDataType.scala    |  7 ++++++
 .../apache/spark/sql/types/ArrayType.scala    |  2 ++
 .../apache/spark/sql/types/DecimalType.scala  |  2 ++
 .../org/apache/spark/sql/types/MapType.scala  |  2 ++
 .../apache/spark/sql/types/StructType.scala   |  2 ++
 .../analysis/HiveTypeCoercionSuite.scala      | 25 ++++++++++++++++++-
 7 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 583338da5711..aa07d2efe7e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -40,7 +40,7 @@ trait CheckAnalysis {
   def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
     exprs.flatMap(_.collect {
       case e: Generator => true
-    }).length >= 1
+    }).nonEmpty
   }
 
   def checkAnalysis(plan: LogicalPlan): Unit = {
@@ -85,12 +85,12 @@ trait CheckAnalysis {
           case Aggregate(groupingExprs, aggregateExprs, child) =>
             def checkValidAggregateExpression(expr: Expression): Unit = expr match {
               case _: AggregateExpression => // OK
-              case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
+              case e: Attribute if groupingExprs.exists(_.semanticEquals(e)) =>
                 failAnalysis(
                   s"expression '${e.prettyString}' is neither present in the group by, " +
                     s"nor is it an aggregate function. " +
                     "Add to group by or wrap in first() if you don't care which value you get.")
-              case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
+              case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
               case e if e.references.isEmpty => // OK
               case e => e.children.foreach(checkValidAggregateExpression)
             }
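For context, the two rewrites in the hunk above rely on standard Scala collection equivalences: `xs.length >= 1` is the same test as `xs.nonEmpty`, and `xs.find(p).isDefined` is the same test as `xs.exists(p)`, whereas `xs.find(p).isEmpty` corresponds to `!xs.exists(p)`. The following is a minimal, standalone sketch of those equivalences using plain strings; the names are illustrative only and are not Spark API.

    // Illustrative only: plain Scala collections standing in for Catalyst expressions.
    object PredicateRewriteSketch extends App {
      val groupingExprs = Seq("a", "b")

      // .length >= 1 and .nonEmpty agree for every sequence.
      assert((groupingExprs.length >= 1) == groupingExprs.nonEmpty)

      // find(p).isDefined is equivalent to exists(p) ...
      assert(groupingExprs.find(_ == "a").isDefined == groupingExprs.exists(_ == "a"))

      // ... while find(p).isEmpty is equivalent to !exists(p), so dropping the
      // negation changes the meaning (see PATCH 2/2 below).
      assert(groupingExprs.find(_ == "c").isEmpty == !groupingExprs.exists(_ == "c"))
    }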
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index e5dc99fb625d..ffefb0e7837e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -37,6 +37,9 @@ private[sql] abstract class AbstractDataType {
    * Returns true if this data type is a parent of the `childCandidate`.
    */
   private[sql] def isParentOf(childCandidate: DataType): Boolean
+
+  /** Readable string representation for the type. */
+  private[sql] def simpleString: String
 }
 
 
@@ -56,6 +59,10 @@ private[sql] class TypeCollection(private val types: Seq[DataType]) extends Abst
   private[sql] override def defaultConcreteType: DataType = types.head
 
   private[sql] override def isParentOf(childCandidate: DataType): Boolean = false
+
+  private[sql] override def simpleString: String = {
+    types.map(_.simpleString).mkString("(", " or ", ")")
+  }
 }
 
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 8ea6cb14c360..43413ec761e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -31,6 +31,8 @@ object ArrayType extends AbstractDataType {
   private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
     childCandidate.isInstanceOf[ArrayType]
   }
+
+  private[sql] override def simpleString: String = "array"
 }
 
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 434fc037aad4..127b16ff85be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -90,6 +90,8 @@ object DecimalType extends AbstractDataType {
     childCandidate.isInstanceOf[DecimalType]
   }
 
+  private[sql] override def simpleString: String = "decimal"
+
   val Unlimited: DecimalType = DecimalType(None)
 
   private[sql] object Fixed {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index 2b25617ec665..868dea13d971 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -75,6 +75,8 @@ object MapType extends AbstractDataType {
     childCandidate.isInstanceOf[MapType]
   }
 
+  private[sql] override def simpleString: String = "map"
+
   /**
    * Construct a [[MapType]] object with the given key type and value type.
    * The `valueContainsNull` is true.
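The `simpleString` members added above give each abstract type a short, human-readable name, and `TypeCollection` renders its alternatives joined with " or ", presumably for readable output such as type-mismatch messages. The following is a self-contained sketch of the same pattern; the trait, class, and object names here are illustrative stand-ins, not the actual Catalyst types.

    // Standalone sketch of the simpleString pattern; names are illustrative only.
    trait TypeName {
      def simpleString: String
    }

    object ArrayLike extends TypeName { def simpleString: String = "array" }
    object DecimalLike extends TypeName { def simpleString: String = "decimal" }

    // Mirrors TypeCollection.simpleString: alternatives joined with " or ".
    class Alternatives(types: Seq[TypeName]) extends TypeName {
      def simpleString: String =
        types.map(_.simpleString).mkString("(", " or ", ")")
    }

    object SimpleStringDemo extends App {
      // Prints "(array or decimal)".
      println(new Alternatives(Seq(ArrayLike, DecimalLike)).simpleString)
    }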
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 7e77b77e7394..3b17566d54d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -309,6 +309,8 @@ object StructType extends AbstractDataType {
     childCandidate.isInstanceOf[StructType]
   }
 
+  private[sql] override def simpleString: String = "struct"
+
   def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
 
   def apply(fields: java.util.List[StructField]): StructType = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 60e727c6c7d4..67d05ab536b7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.types._
 
 class HiveTypeCoercionSuite extends PlanTest {
 
-  test("implicit type cast") {
+  test("eligible implicit type cast") {
     def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
       val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
       assert(got.map(_.dataType) == Option(expected),
@@ -68,6 +68,29 @@ class HiveTypeCoercionSuite extends PlanTest {
     shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType)
     shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
     shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)
+
+    shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType)
+    shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType)
+  }
+
+  test("ineligible implicit type cast") {
+    def shouldNotCast(from: DataType, to: AbstractDataType): Unit = {
+      val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
+      assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got")
+    }
+
+    shouldNotCast(IntegerType, DateType)
+    shouldNotCast(IntegerType, TimestampType)
+    shouldNotCast(LongType, DateType)
+    shouldNotCast(LongType, TimestampType)
+    shouldNotCast(DecimalType.Unlimited, DateType)
+    shouldNotCast(DecimalType.Unlimited, TimestampType)
+
+    shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType))
+
+    shouldNotCast(IntegerType, ArrayType)
+    shouldNotCast(IntegerType, MapType)
+    shouldNotCast(IntegerType, StructType)
   }
 
   test("tightest common bound for types") {
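Both new tests follow the same shape: `implicitCast` returns an `Option` — a rewritten expression when a coercion is eligible, `None` when it is not — and the helpers turn that into readable assertions. The sketch below is a simplified, self-contained version of that pattern; `castTo` and the toy type objects are hypothetical stand-ins, not the real `HiveTypeCoercion.ImplicitTypeCasts.implicitCast` rule set.

    // Simplified stand-in for the Option-based test helpers in the suite above.
    object CastAssertionSketch extends App {
      sealed trait SimpleType
      case object IntT extends SimpleType
      case object StringT extends SimpleType
      case object DateT extends SimpleType

      // Toy rule set: identity casts and int-to-string are allowed, nothing else.
      def castTo(from: SimpleType, to: SimpleType): Option[SimpleType] =
        (from, to) match {
          case (t1, t2) if t1 == t2 => Some(t1)
          case (IntT, StringT)      => Some(StringT)
          case _                    => None
        }

      def shouldCast(from: SimpleType, to: SimpleType, expected: SimpleType): Unit =
        assert(castTo(from, to).contains(expected), s"Failed to cast $from to $to")

      def shouldNotCast(from: SimpleType, to: SimpleType): Unit =
        assert(castTo(from, to).isEmpty, s"Should not be able to cast $from to $to")

      shouldCast(IntT, StringT, StringT)
      shouldNotCast(IntT, DateT)
    }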
From 64b13bda117767e4d3c1c2e7cabb8f8574c88438 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sat, 4 Jul 2015 00:33:17 -0700
Subject: [PATCH 2/2] Fixed a bug ..

---
 .../org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index aa07d2efe7e9..476ac2b7cb47 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -85,7 +85,7 @@ trait CheckAnalysis {
           case Aggregate(groupingExprs, aggregateExprs, child) =>
             def checkValidAggregateExpression(expr: Expression): Unit = expr match {
               case _: AggregateExpression => // OK
-              case e: Attribute if groupingExprs.exists(_.semanticEquals(e)) =>
+              case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
                 failAnalysis(
                   s"expression '${e.prettyString}' is neither present in the group by, " +
                     s"nor is it an aggregate function. " +
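The second commit restores the negation that the first commit dropped: the analysis failure should fire only when an attribute does not appear among the grouping expressions, which is what the original `find(_ semanticEquals e).isEmpty` guard expressed. A minimal sketch of the difference, using plain strings in place of Catalyst attributes and simple equality in place of `semanticEquals`; the names are illustrative only.

    // Plain-Scala illustration of the guard fixed in PATCH 2/2.
    object GroupByGuardSketch extends App {
      val groupingExprs = Seq("deptId")

      def buggyGuard(e: String): Boolean = groupingExprs.exists(_ == e)    // PATCH 1/2
      def fixedGuard(e: String): Boolean = !groupingExprs.exists(_ == e)   // PATCH 2/2

      // "deptId" is grouped on, so it must NOT trigger the failure branch.
      assert(buggyGuard("deptId"))   // would wrongly fail analysis
      assert(!fixedGuard("deptId"))  // correctly accepted

      // "salary" is not grouped on, so it SHOULD trigger the failure branch.
      assert(!buggyGuard("salary"))  // would wrongly slip past this case
      assert(fixedGuard("salary"))   // correctly flagged
    }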