Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ trait CheckAnalysis {
def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
exprs.flatMap(_.collect {
case e: Generator => true
}).length >= 1
}).nonEmpty
}

def checkAnalysis(plan: LogicalPlan): Unit = {
Expand Down Expand Up @@ -85,12 +85,12 @@ trait CheckAnalysis {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
"Add to group by or wrap in first() if you don't care which value you get.")
case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
case e if e.references.isEmpty => // OK
case e => e.children.foreach(checkValidAggregateExpression)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ private[sql] abstract class AbstractDataType {
* Returns true if this data type is a parent of the `childCandidate`.
*/
private[sql] def isParentOf(childCandidate: DataType): Boolean

/** Readable string representation for the type. */
private[sql] def simpleString: String
}


Expand All @@ -56,6 +59,10 @@ private[sql] class TypeCollection(private val types: Seq[DataType]) extends Abst
private[sql] override def defaultConcreteType: DataType = types.head

private[sql] override def isParentOf(childCandidate: DataType): Boolean = false

private[sql] override def simpleString: String = {
types.map(_.simpleString).mkString("(", " or ", ")")
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ object ArrayType extends AbstractDataType {
private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
childCandidate.isInstanceOf[ArrayType]
}

private[sql] override def simpleString: String = "array"
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ object DecimalType extends AbstractDataType {
childCandidate.isInstanceOf[DecimalType]
}

private[sql] override def simpleString: String = "decimal"

val Unlimited: DecimalType = DecimalType(None)

private[sql] object Fixed {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ object MapType extends AbstractDataType {
childCandidate.isInstanceOf[MapType]
}

private[sql] override def simpleString: String = "map"

/**
* Construct a [[MapType]] object with the given key type and value type.
* The `valueContainsNull` is true.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,8 @@ object StructType extends AbstractDataType {
childCandidate.isInstanceOf[StructType]
}

private[sql] override def simpleString: String = "struct"

def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)

def apply(fields: java.util.List[StructField]): StructType = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.types._

class HiveTypeCoercionSuite extends PlanTest {

test("implicit type cast") {
test("eligible implicit type cast") {
def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
assert(got.map(_.dataType) == Option(expected),
Expand Down Expand Up @@ -68,6 +68,29 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType)
shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)

shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType)
shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType)
}

test("ineligible implicit type cast") {
def shouldNotCast(from: DataType, to: AbstractDataType): Unit = {
val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got")
}

shouldNotCast(IntegerType, DateType)
shouldNotCast(IntegerType, TimestampType)
shouldNotCast(LongType, DateType)
shouldNotCast(LongType, TimestampType)
shouldNotCast(DecimalType.Unlimited, DateType)
shouldNotCast(DecimalType.Unlimited, TimestampType)

shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType))

shouldNotCast(IntegerType, ArrayType)
shouldNotCast(IntegerType, MapType)
shouldNotCast(IntegerType, StructType)
}

test("tightest common bound for types") {
Expand Down