Skip to content

Commit 4137f76

Browse files
committed
[SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types.
This patch doesn't actually introduce any code that uses the new ExpectsInputTypes. It just adds the trait so others can use it. Also renamed the old expectsInputTypes function to just inputTypes. We should add implicit type casting also in the future. Author: Reynold Xin <[email protected]> Closes apache#7151 from rxin/expects-input-types and squashes the following commits: 16cf07b [Reynold Xin] [SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types.
1 parent 69c5dee commit 4137f76

File tree

7 files changed

+44
-24
lines changed

7 files changed

+44
-24
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import org.apache.spark.sql.types._
2626
* Throws user facing errors when passed invalid queries that fail to analyze.
2727
*/
2828
trait CheckAnalysis {
29-
self: Analyzer =>
3029

3130
/**
3231
* Override to provide additional checks for correct analysis.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ object HiveTypeCoercion {
4545
IfCoercion ::
4646
Division ::
4747
PropagateTypes ::
48-
AddCastForAutoCastInputTypes ::
48+
ImplicitTypeCasts ::
4949
Nil
5050

5151
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
@@ -705,13 +705,13 @@ object HiveTypeCoercion {
705705
* Casts types according to the expected input types for Expressions that have the trait
706706
* [[AutoCastInputTypes]].
707707
*/
708-
object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] {
708+
object ImplicitTypeCasts extends Rule[LogicalPlan] {
709709
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
710710
// Skip nodes whose children have not been resolved yet.
711711
case e if !e.childrenResolved => e
712712

713-
case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes =>
714-
val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map {
713+
case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes =>
714+
val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map {
715715
case (child, actual, expected) =>
716716
if (actual == expected) child else Cast(child, expected)
717717
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,17 +265,38 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
265265
}
266266
}
267267

268+
/**
269+
* A trait that gets mixed in to define the expected input types of an expression.
270+
*/
271+
trait ExpectsInputTypes { self: Expression =>
272+
273+
/**
274+
* Expected input types from child expressions. The i-th position in the returned seq indicates
275+
* the type requirement for the i-th child.
276+
*
277+
* The possible values at each position are:
278+
* 1. a specific data type, e.g. LongType, StringType.
279+
* 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType.
280+
* 3. a list of specific data types, e.g. Seq(StringType, BinaryType).
281+
*/
282+
def inputTypes: Seq[Any]
283+
284+
override def checkInputDataTypes(): TypeCheckResult = {
285+
// We will do the type checking in `HiveTypeCoercion`, so always returning success here.
286+
TypeCheckResult.TypeCheckSuccess
287+
}
288+
}
289+
268290
/**
269291
* Expressions that require a specific `DataType` as input should implement this trait
270292
* so that the proper type conversions can be performed in the analyzer.
271293
*/
272-
trait AutoCastInputTypes {
273-
self: Expression =>
294+
trait AutoCastInputTypes { self: Expression =>
274295

275-
def expectedChildTypes: Seq[DataType]
296+
def inputTypes: Seq[DataType]
276297

277298
override def checkInputDataTypes(): TypeCheckResult = {
278-
// We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
299+
// We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`,
279300
// so type mismatch errors won't be reported here, but by the underlying `Cast`s.
280301
TypeCheckResult.TypeCheckSuccess
281302
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
5959
extends UnaryExpression with Serializable with AutoCastInputTypes {
6060
self: Product =>
6161

62-
override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
62+
override def inputTypes: Seq[DataType] = Seq(DoubleType)
6363
override def dataType: DataType = DoubleType
6464
override def nullable: Boolean = true
6565
override def toString: String = s"$name($child)"
@@ -98,7 +98,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
9898
abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
9999
extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product =>
100100

101-
override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
101+
override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
102102

103103
override def toString: String = s"$name($left, $right)"
104104

@@ -210,7 +210,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
210210
case class Bin(child: Expression)
211211
extends UnaryExpression with Serializable with AutoCastInputTypes {
212212

213-
override def expectedChildTypes: Seq[DataType] = Seq(LongType)
213+
override def inputTypes: Seq[DataType] = Seq(LongType)
214214
override def dataType: DataType = StringType
215215

216216
override def eval(input: InternalRow): Any = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ case class Md5(child: Expression)
3636

3737
override def dataType: DataType = StringType
3838

39-
override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
39+
override def inputTypes: Seq[DataType] = Seq(BinaryType)
4040

4141
override def eval(input: InternalRow): Any = {
4242
val value = child.eval(input)
@@ -68,7 +68,7 @@ case class Sha2(left: Expression, right: Expression)
6868

6969
override def toString: String = s"SHA2($left, $right)"
7070

71-
override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
71+
override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
7272

7373
override def eval(input: InternalRow): Any = {
7474
val evalE1 = left.eval(input)
@@ -151,7 +151,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp
151151

152152
override def dataType: DataType = StringType
153153

154-
override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
154+
override def inputTypes: Seq[DataType] = Seq(BinaryType)
155155

156156
override def eval(input: InternalRow): Any = {
157157
val value = child.eval(input)
@@ -179,7 +179,7 @@ case class Crc32(child: Expression)
179179

180180
override def dataType: DataType = LongType
181181

182-
override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
182+
override def inputTypes: Seq[DataType] = Seq(BinaryType)
183183

184184
override def eval(input: InternalRow): Any = {
185185
val value = child.eval(input)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ trait PredicateHelper {
7272
case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes {
7373
override def toString: String = s"NOT $child"
7474

75-
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
75+
override def inputTypes: Seq[DataType] = Seq(BooleanType)
7676

7777
override def eval(input: InternalRow): Any = {
7878
child.eval(input) match {
@@ -122,7 +122,7 @@ case class InSet(value: Expression, hset: Set[Any])
122122
case class And(left: Expression, right: Expression)
123123
extends BinaryExpression with Predicate with AutoCastInputTypes {
124124

125-
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
125+
override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
126126

127127
override def symbol: String = "&&"
128128

@@ -171,7 +171,7 @@ case class And(left: Expression, right: Expression)
171171
case class Or(left: Expression, right: Expression)
172172
extends BinaryExpression with Predicate with AutoCastInputTypes {
173173

174-
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
174+
override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
175175

176176
override def symbol: String = "||"
177177

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ trait StringRegexExpression extends AutoCastInputTypes {
3232

3333
override def nullable: Boolean = left.nullable || right.nullable
3434
override def dataType: DataType = BooleanType
35-
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
35+
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
3636

3737
// try cache the pattern for Literal
3838
private lazy val cache: Pattern = right match {
@@ -117,7 +117,7 @@ trait CaseConversionExpression extends AutoCastInputTypes {
117117
def convert(v: UTF8String): UTF8String
118118

119119
override def dataType: DataType = StringType
120-
override def expectedChildTypes: Seq[DataType] = Seq(StringType)
120+
override def inputTypes: Seq[DataType] = Seq(StringType)
121121

122122
override def eval(input: InternalRow): Any = {
123123
val evaluated = child.eval(input)
@@ -165,7 +165,7 @@ trait StringComparison extends AutoCastInputTypes {
165165

166166
override def nullable: Boolean = left.nullable || right.nullable
167167

168-
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
168+
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
169169

170170
override def eval(input: InternalRow): Any = {
171171
val leftEval = left.eval(input)
@@ -238,7 +238,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
238238
if (str.dataType == BinaryType) str.dataType else StringType
239239
}
240240

241-
override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
241+
override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
242242

243243
override def children: Seq[Expression] = str :: pos :: len :: Nil
244244

@@ -297,7 +297,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
297297
*/
298298
case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes {
299299
override def dataType: DataType = IntegerType
300-
override def expectedChildTypes: Seq[DataType] = Seq(StringType)
300+
override def inputTypes: Seq[DataType] = Seq(StringType)
301301

302302
override def eval(input: InternalRow): Any = {
303303
val string = child.eval(input)

0 commit comments

Comments
 (0)