Skip to content
Closed
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
0b293db
[SPARK-29774][SQL] Date and Timestamp type +/- null should be null as…
yaooqinn Nov 6, 2019
f726297
Merge branch 'master' into SPARK-29774
yaooqinn Nov 27, 2019
e7225a3
regen golden file
yaooqinn Nov 27, 2019
b925517
null - dates
yaooqinn Nov 27, 2019
cd49411
Merge branch 'master' into SPARK-29774
yaooqinn Dec 2, 2019
57b13e9
support +/-
yaooqinn Dec 2, 2019
eab6a83
support ×/÷
yaooqinn Dec 2, 2019
e8b75ba
import
yaooqinn Dec 2, 2019
02b3738
childResolved required
yaooqinn Dec 2, 2019
e89d806
regen golden file
yaooqinn Dec 2, 2019
0694e07
update comments
yaooqinn Dec 2, 2019
0f5618b
fix tests
yaooqinn Dec 3, 2019
efab3ec
fix tests
yaooqinn Dec 3, 2019
1c27be1
refine case match pattern
yaooqinn Dec 3, 2019
5df6980
fix ut
yaooqinn Dec 3, 2019
9817d2d
hack assert Equal
yaooqinn Dec 3, 2019
9808b9c
regen g f
yaooqinn Dec 3, 2019
b190612
AnalysisTest
yaooqinn Dec 3, 2019
e544137
regen g f
yaooqinn Dec 3, 2019
83705fd
fix test
yaooqinn Dec 4, 2019
846802d
date add/sub only work for int/smallint/tinyint
yaooqinn Dec 4, 2019
4af7edb
regen g f
yaooqinn Dec 4, 2019
a67be30
refine
yaooqinn Dec 4, 2019
9a1affd
type coercion for subtract timestamp
yaooqinn Dec 4, 2019
ae70022
add and reorgnize tests in datetime.sql
yaooqinn Dec 4, 2019
571225b
DateExpressionsSuite
yaooqinn Dec 4, 2019
6052e5a
fix py
yaooqinn Dec 4, 2019
928fd86
fix py
yaooqinn Dec 4, 2019
254d2d2
Revert "fix py"
yaooqinn Dec 5, 2019
5dd632c
fix py
yaooqinn Dec 5, 2019
c84d46e
rm unresolved binary arithmetic
yaooqinn Dec 5, 2019
a44948e
import
yaooqinn Dec 5, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class Analyzer(
ResolveLambdaVariables(conf) ::
ResolveTimeZone(conf) ::
ResolveRandomSeed ::
ResolveBinaryArithmetic(conf) ::
TypeCoercion.typeCoercionRules(conf) ++
extendedResolutionRules : _*),
Batch("PostgreSQL Dialect", Once, PostgreSQLDialect.postgreSQLDialectRules: _*),
Expand All @@ -246,6 +247,65 @@ class Analyzer(
CleanupAliases)
)

/**
 * Replaces the placeholder arithmetic nodes [[UnresolvedAdd]], [[UnresolvedSubtract]],
 * [[UnresolvedMultiply]] and [[UnresolvedDivide]] with concrete expressions. A placeholder is
 * only rewritten once its children are resolved; the replacement is chosen from the pair
 * (left data type, right data type), and for each operator the first matching row wins.
 *
 * [[UnresolvedAdd]]:
 *  1. interval + interval: [[Add]];
 *  2. other + interval: [[TimeAdd]], cast back to the left operand's type;
 *  3. interval + other: [[TimeAdd]] with the operands swapped, cast back to the right
 *     operand's type;
 *  4. date + other: [[DateAdd]];
 *  5. other + date: [[DateAdd]] with the operands swapped;
 *  6. anything else: [[Add]].
 *
 * [[UnresolvedSubtract]]:
 *  1. interval - interval: [[Subtract]];
 *  2. other - interval: [[TimeSub]], cast back to the left operand's type;
 *  3. a timestamp on either side: [[SubtractTimestamps]];
 *  4. other - date: [[DateDiff]] when `conf.usePostgreSQLDialect` is set,
 *     otherwise [[SubtractDates]];
 *  5. date - other: [[DateSub]];
 *  6. anything else: [[Subtract]].
 *
 * [[UnresolvedMultiply]]:
 *  1. an interval on either side: [[MultiplyInterval]], with the interval as first argument;
 *  2. anything else: [[Multiply]].
 *
 * [[UnresolvedDivide]]:
 *  1. interval / other: [[DivideInterval]];
 *  2. anything else: [[Divide]].
 */
case class ResolveBinaryArithmetic(conf: SQLConf) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
    case operator: LogicalPlan => operator.transformExpressionsUp {
      case add @ UnresolvedAdd(lhs, rhs) if add.childrenResolved =>
        (lhs.dataType, rhs.dataType) match {
          case (CalendarIntervalType, CalendarIntervalType) => Add(lhs, rhs)
          case (_, CalendarIntervalType) => Cast(TimeAdd(lhs, rhs), lhs.dataType)
          // TimeAdd takes the interval as its second argument, so swap the operands.
          case (CalendarIntervalType, _) => Cast(TimeAdd(rhs, lhs), rhs.dataType)
          case (DateType, _) => DateAdd(lhs, rhs)
          // DateAdd takes the date as its first argument, so swap the operands.
          case (_, DateType) => DateAdd(rhs, lhs)
          case _ => Add(lhs, rhs)
        }

      case sub @ UnresolvedSubtract(lhs, rhs) if sub.childrenResolved =>
        (lhs.dataType, rhs.dataType) match {
          case (CalendarIntervalType, CalendarIntervalType) => Subtract(lhs, rhs)
          case (_, CalendarIntervalType) => Cast(TimeSub(lhs, rhs), lhs.dataType)
          case (TimestampType, _) | (_, TimestampType) => SubtractTimestamps(lhs, rhs)
          case (_, DateType) if conf.usePostgreSQLDialect => DateDiff(lhs, rhs)
          case (_, DateType) => SubtractDates(lhs, rhs)
          case (DateType, _) => DateSub(lhs, rhs)
          case _ => Subtract(lhs, rhs)
        }

      case mul @ UnresolvedMultiply(lhs, rhs) if mul.childrenResolved =>
        (lhs.dataType, rhs.dataType) match {
          case (CalendarIntervalType, _) => MultiplyInterval(lhs, rhs)
          // MultiplyInterval takes the interval as its first argument, so swap the operands.
          case (_, CalendarIntervalType) => MultiplyInterval(rhs, lhs)
          case _ => Multiply(lhs, rhs)
        }

      case div @ UnresolvedDivide(lhs, rhs) if div.childrenResolved =>
        (lhs.dataType, rhs.dataType) match {
          case (CalendarIntervalType, _) => DivideInterval(lhs, rhs)
          case _ => Divide(lhs, rhs)
        }
    }
  }
}
/**
* Substitute child plan with WindowSpecDefinitions.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -822,52 +822,24 @@ object TypeCoercion {
}
}

/**
* 1. Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
* to TimeAdd/TimeSub.
* 2. Turns Add/Subtract of TimestampType/DateType/IntegerType
* and TimestampType/IntegerType/DateType to DateAdd/DateSub/SubtractDates and
* to SubtractTimestamps.
* 3. Turns Multiply/Divide of CalendarIntervalType and NumericType
* to MultiplyInterval/DivideInterval
*/
object DateTimeOperations extends Rule[LogicalPlan] {

private val acceptedTypes = Seq(DateType, TimestampType, StringType)

def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e

case Add(l @ CalendarIntervalType(), r) if acceptedTypes.contains(r.dataType) =>
Cast(TimeAdd(r, l), r.dataType)
case Add(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) =>
Cast(TimeAdd(l, r), l.dataType)
case Subtract(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) =>
Cast(TimeSub(l, r), l.dataType)
case Multiply(l @ CalendarIntervalType(), r @ NumericType()) =>
MultiplyInterval(l, r)
case Multiply(l @ NumericType(), r @ CalendarIntervalType()) =>
MultiplyInterval(r, l)
case Divide(l @ CalendarIntervalType(), r @ NumericType()) =>
DivideInterval(l, r)

case b @ BinaryOperator(l @ CalendarIntervalType(), r @ NullType()) =>
b.withNewChildren(Seq(l, Cast(r, CalendarIntervalType)))
case b @ BinaryOperator(l @ NullType(), r @ CalendarIntervalType()) =>
b.withNewChildren(Seq(Cast(l, CalendarIntervalType), r))

case Add(l @ DateType(), r @ IntegerType()) => DateAdd(l, r)
case Add(l @ IntegerType(), r @ DateType()) => DateAdd(r, l)
case Subtract(l @ DateType(), r @ IntegerType()) => DateSub(l, r)
case Subtract(l @ DateType(), r @ DateType()) =>
if (SQLConf.get.usePostgreSQLDialect) DateDiff(l, r) else SubtractDates(l, r)
case Subtract(l @ TimestampType(), r @ TimestampType()) =>
SubtractTimestamps(l, r)
case Subtract(l @ TimestampType(), r @ DateType()) =>
SubtractTimestamps(l, Cast(r, TimestampType))
case Subtract(l @ DateType(), r @ TimestampType()) =>
SubtractTimestamps(Cast(l, TimestampType), r)
case d @ DateAdd(TimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
case d @ DateAdd(StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
case d @ DateSub(TimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType))
case d @ DateSub(StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType))

case s @ SubtractTimestamps(DateType(), _) =>
s.copy(endTimestamp = Cast(s.endTimestamp, TimestampType))
case s @ SubtractTimestamps(_, DateType()) =>
s.copy(startTimestamp = Cast(s.startTimestamp, TimestampType))

case t @ TimeAdd(DateType(), _, _) => t.copy(start = Cast(t.start, TimestampType))
case t @ TimeAdd(StringType(), _, _) => t.copy(start = Cast(t.start, TimestampType))
case t @ TimeSub(DateType(), _, _) => t.copy(start = Cast(t.start, TimestampType))
case t @ TimeSub(StringType(), _, _) => t.copy(start = Cast(t.start, TimestampType))
}
}

Expand All @@ -881,11 +853,8 @@ object TypeCoercion {
case e if !e.childrenResolved => e

// If DecimalType operands are involved, DecimalPrecision will handle it
// If CalendarIntervalType operands are involved, DateTimeOperations will handle it
case b @ BinaryOperator(left, right) if !left.dataType.isInstanceOf[DecimalType] &&
!right.dataType.isInstanceOf[DecimalType] &&
!left.dataType.isInstanceOf[CalendarIntervalType] &&
!right.dataType.isInstanceOf[CalendarIntervalType] &&
left.dataType != right.dataType =>
findTightestCommonType(left.dataType, right.dataType).map { commonType =>
if (b.inputType.acceptsType(commonType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,3 +544,19 @@ case class UnresolvedOrdinal(ordinal: Int)
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
}

/**
 * Common base for binary placeholders whose concrete expression has not been chosen yet.
 * Such a node always reports `resolved = false`, and asking it for a `dataType` throws an
 * [[UnresolvedException]].
 */
trait UnresolvedBinaryExpression extends BinaryExpression with Unevaluable {
  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
  override lazy val resolved: Boolean = false
}

/** Placeholder for `left + right`. */
case class UnresolvedAdd(left: Expression, right: Expression) extends UnresolvedBinaryExpression

/** Placeholder for `left - right`. */
case class UnresolvedSubtract(left: Expression, right: Expression)
  extends UnresolvedBinaryExpression

/** Placeholder for `left * right`. */
case class UnresolvedMultiply(left: Expression, right: Expression)
  extends UnresolvedBinaryExpression

/** Placeholder for `left / right`. */
case class UnresolvedDivide(left: Expression, right: Expression)
  extends UnresolvedBinaryExpression
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,18 @@ case class CurrentBatchTimestamp(
""",
since = "1.5.0")
case class DateAdd(startDate: Expression, days: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
extends BinaryExpression with ExpectsInputTypes {

override def left: Expression = startDate
override def right: Expression = days

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
override def inputTypes: Seq[AbstractDataType] =
Seq(DateType, TypeCollection(IntegerType, ShortType, ByteType))

override def dataType: DataType = DateType

override def nullSafeEval(start: Any, d: Any): Any = {
start.asInstanceOf[Int] + d.asInstanceOf[Int]
start.asInstanceOf[Int] + d.asInstanceOf[Number].intValue()
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand All @@ -185,16 +186,17 @@ case class DateAdd(startDate: Expression, days: Expression)
""",
since = "1.5.0")
case class DateSub(startDate: Expression, days: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
extends BinaryExpression with ExpectsInputTypes {
override def left: Expression = startDate
override def right: Expression = days

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
override def inputTypes: Seq[AbstractDataType] =
Seq(DateType, TypeCollection(IntegerType, ShortType, ByteType))

override def dataType: DataType = DateType

override def nullSafeEval(start: Any, d: Any): Any = {
start.asInstanceOf[Int] - d.asInstanceOf[Int]
start.asInstanceOf[Int] - d.asInstanceOf[Number].intValue()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add some UT in DateExpressionsSuite to make sure byte/short works?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -1072,7 +1074,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)
* Adds an interval to timestamp.
*/
case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None)
extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes {

def this(start: Expression, interval: Expression) = this(start, interval, None)

Expand Down Expand Up @@ -1187,7 +1189,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
* Subtracts an interval from timestamp.
*/
case class TimeSub(start: Expression, interval: Expression, timeZoneId: Option[String] = None)
extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes {

def this(start: Expression, interval: Expression) = this(start, interval, None)

Expand Down Expand Up @@ -2127,7 +2129,7 @@ case class DatePart(field: Expression, source: Expression, child: Expression)
* between the given timestamps.
*/
case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
extends BinaryExpression with ExpectsInputTypes {

override def left: Expression = endTimestamp
override def right: Expression = startTimestamp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1429,17 +1429,17 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
val right = expression(ctx.right)
ctx.operator.getType match {
case SqlBaseParser.ASTERISK =>
Multiply(left, right)
UnresolvedMultiply(left, right)
case SqlBaseParser.SLASH =>
Divide(left, right)
UnresolvedDivide(left, right)
case SqlBaseParser.PERCENT =>
Remainder(left, right)
case SqlBaseParser.DIV =>
IntegralDivide(left, right)
case SqlBaseParser.PLUS =>
Add(left, right)
UnresolvedAdd(left, right)
case SqlBaseParser.MINUS =>
Subtract(left, right)
UnresolvedSubtract(left, right)
case SqlBaseParser.CONCAT_PIPE =>
Concat(left :: right :: Nil)
case SqlBaseParser.AMPERSAND =>
Expand Down Expand Up @@ -1696,9 +1696,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
*/
override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) {
def value: Expression = {
val e = expression(ctx.expression)
validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx)
e
expression(ctx.expression) match {
case u: UnresolvedBinaryExpression if u.childrenResolved && u.foldable => u
case e =>
validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx)
e
}
}

ctx.boundType.getType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1401,44 +1401,6 @@ class TypeCoercionSuite extends AnalysisTest {
}
}

// Checks how TypeCoercion.DateTimeOperations rewrites Add/Subtract over date, timestamp,
// string, interval and integer literals.
test("rule for date/timestamp operations") {
val dateTimeOperations = TypeCoercion.DateTimeOperations
val date = Literal(new java.sql.Date(0L))
val timestamp = Literal(new Timestamp(0L))
val interval = Literal(new CalendarInterval(0, 0, 0))
val str = Literal("2015-01-01")
val intValue = Literal(0, IntegerType)

// Adding an interval (on either side) becomes TimeAdd with the interval as second argument,
// cast back to the type of the non-interval operand.
ruleTest(dateTimeOperations, Add(date, interval), Cast(TimeAdd(date, interval), DateType))
ruleTest(dateTimeOperations, Add(interval, date), Cast(TimeAdd(date, interval), DateType))
ruleTest(dateTimeOperations, Add(timestamp, interval),
Cast(TimeAdd(timestamp, interval), TimestampType))
ruleTest(dateTimeOperations, Add(interval, timestamp),
Cast(TimeAdd(timestamp, interval), TimestampType))
ruleTest(dateTimeOperations, Add(str, interval), Cast(TimeAdd(str, interval), StringType))
ruleTest(dateTimeOperations, Add(interval, str), Cast(TimeAdd(str, interval), StringType))

// Subtracting an interval becomes TimeSub, cast back to the type of the left operand.
ruleTest(dateTimeOperations, Subtract(date, interval), Cast(TimeSub(date, interval), DateType))
ruleTest(dateTimeOperations, Subtract(timestamp, interval),
Cast(TimeSub(timestamp, interval), TimestampType))
ruleTest(dateTimeOperations, Subtract(str, interval), Cast(TimeSub(str, interval), StringType))

// interval operations should not be affected
ruleTest(dateTimeOperations, Add(interval, interval), Add(interval, interval))
ruleTest(dateTimeOperations, Subtract(interval, interval), Subtract(interval, interval))

// Date/integer arithmetic becomes DateAdd/DateSub; date and timestamp differences become
// SubtractDates/SubtractTimestamps, with a date operand cast to TimestampType when mixed.
ruleTest(dateTimeOperations, Add(date, intValue), DateAdd(date, intValue))
ruleTest(dateTimeOperations, Add(intValue, date), DateAdd(date, intValue))
ruleTest(dateTimeOperations, Subtract(date, intValue), DateSub(date, intValue))
ruleTest(dateTimeOperations, Subtract(date, date), SubtractDates(date, date))
ruleTest(dateTimeOperations, Subtract(timestamp, timestamp),
SubtractTimestamps(timestamp, timestamp))
ruleTest(dateTimeOperations, Subtract(timestamp, date),
SubtractTimestamps(timestamp, Cast(date, TimestampType)))
ruleTest(dateTimeOperations, Subtract(date, timestamp),
SubtractTimestamps(Cast(date, TimestampType), timestamp))
}

/**
* There are rules that need to not fire before child expressions get resolved.
* We use this test to make sure those rules do not fire early.
Expand Down Expand Up @@ -1586,27 +1548,6 @@ class TypeCoercionSuite extends AnalysisTest {
Multiply(CaseWhen(Seq((EqualTo(1, 2), Cast(1, DecimalType(34, 24)))),
Cast(100, DecimalType(34, 24))), Cast(1, IntegerType)))
}

// Checks that TypeCoercion.DateTimeOperations rewrites interval * number, number * interval
// and interval / number into MultiplyInterval / DivideInterval for each numeric literal below.
test("rule for interval operations") {
val dateTimeOperations = TypeCoercion.DateTimeOperations
val interval = Literal(new CalendarInterval(0, 0, 0))

// One literal per numeric type exercised: byte, short, int, long, decimal, float, double.
Seq(
Literal(10.toByte, ByteType),
Literal(10.toShort, ShortType),
Literal(10, IntegerType),
Literal(10L, LongType),
Literal(Decimal(10), DecimalType.SYSTEM_DEFAULT),
Literal(10.5.toFloat, FloatType),
Literal(10.5, DoubleType)).foreach { num =>
ruleTest(dateTimeOperations, Multiply(interval, num),
MultiplyInterval(interval, num))
// The interval is expected as the first argument regardless of operand order.
ruleTest(dateTimeOperations, Multiply(num, interval),
MultiplyInterval(interval, num))
ruleTest(dateTimeOperations, Divide(interval, num),
DivideInterval(interval, num))
}
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
*/
package org.apache.spark.sql.catalyst.parser

import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAdd, UnresolvedDivide, UnresolvedMultiply, UnresolvedSubtract}
import org.apache.spark.sql.catalyst.expressions.{Add, Divide, Multiply, Subtract}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
Expand All @@ -28,7 +29,16 @@ class ErrorParserSuite extends AnalysisTest {
import org.apache.spark.sql.catalyst.dsl.plans._

private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = {
assert(parsePlan(sqlCommand) == plan)
val resolvedPlan = parsePlan(sqlCommand) resolveOperatorsUp {
case p: LogicalPlan => p transformAllExpressions {
case UnresolvedAdd(l, r) => Add(l, r)
case UnresolvedSubtract(l, r) => Subtract(l, r)
case UnresolvedMultiply(l, r) => Multiply(l, r)
case UnresolvedDivide(l, r) => Divide(l, r)
case other => other
}
}
assert(resolvedPlan == plan)
}

def intercept(sqlCommand: String, messages: String*): Unit =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,15 @@ class ExpressionParserSuite extends AnalysisTest {
sqlCommand: String,
e: Expression,
parser: ParserInterface = defaultParser): Unit = {
compareExpressions(parser.parseExpression(sqlCommand), e)
// usage for UnresolvedAdd etc here is just for tests
val expression = parser.parseExpression(sqlCommand) transform {
case UnresolvedAdd(l, r) => Add(l, r)
case UnresolvedSubtract(l, r) => Subtract(l, r)
case UnresolvedMultiply(l, r) => Multiply(l, r)
case UnresolvedDivide(l, r) => Divide(l, r)
case other => other
}
compareExpressions(expression, e)
}

private def intercept(sqlCommand: String, messages: String*): Unit =
Expand Down
Loading