Skip to content

Commit 0961cac

Browse files
committed
Merge branch 'SPARK-21774' into 'spark_2.1'
[SPARK-21774] 数字类型和字符串比较的时候都统一转成double类型进行比较 现在字符串和数值的比较都是把字符串转成跟数值一样的数据格式之后再去比较 测试case: `select "1.1" = 1;` `"1.1" = 1`这样的判断,如果是把1.1转成int类型之后就是1了,它就和1相等了... resolve apache#110 See merge request !67
2 parents f0eb740 + bd0ac93 commit 0961cac

File tree

3 files changed

+55
-29
lines changed

3 files changed

+55
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,42 @@ object TypeCoercion {
294294
}
295295
}
296296

297+
/**
298+
* This function determines the target type of a comparison operator when one operand
299+
* is a String and the other is not. It also handles when one op is a Date and the
300+
* other is a Timestamp by making the target type to be String.
301+
*/
302+
val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = {
303+
// We should cast all relative timestamp/date/string comparison into string comparisons
304+
// This behaves as a user would expect because timestamp strings sort lexicographically.
305+
// i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
306+
case (StringType, DateType) => Some(StringType)
307+
case (DateType, StringType) => Some(StringType)
308+
case (StringType, TimestampType) => Some(StringType)
309+
case (TimestampType, StringType) => Some(StringType)
310+
case (TimestampType, DateType) => Some(StringType)
311+
case (DateType, TimestampType) => Some(StringType)
312+
case (StringType, NullType) => Some(StringType)
313+
case (NullType, StringType) => Some(StringType)
314+
case (StringType, r: NumericType) => Some(DoubleType)
315+
case (l: NumericType, StringType) => Some(DoubleType)
316+
case (l: StringType, r: AtomicType) if r != StringType => Some(r)
317+
case (l: AtomicType, r: StringType) if l != StringType => Some(l)
318+
case (l, r) => None
319+
}
320+
297321
/**
298322
* Promotes strings that appear in arithmetic expressions.
299323
*/
300324
object PromoteStrings extends Rule[LogicalPlan] {
325+
private def castExpr(expr: Expression, targetType: DataType): Expression = {
326+
(expr.dataType, targetType) match {
327+
case (NullType, dt) => Literal.create(null, targetType)
328+
case (l, dt) if (l != dt) => Cast(expr, targetType)
329+
case _ => expr
330+
}
331+
}
332+
301333
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
302334
// Skip nodes who's children have not been resolved yet.
303335
case e if !e.childrenResolved => e
@@ -314,34 +346,10 @@ object TypeCoercion {
314346
case p @ Equality(left @ TimestampType(), right @ StringType()) =>
315347
p.makeCopy(Array(left, Cast(right, TimestampType)))
316348

317-
// We should cast all relative timestamp/date/string comparison into string comparisons
318-
// This behaves as a user would expect because timestamp strings sort lexicographically.
319-
// i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
320-
case p @ BinaryComparison(left @ StringType(), right @ DateType()) =>
321-
p.makeCopy(Array(left, Cast(right, StringType)))
322-
case p @ BinaryComparison(left @ DateType(), right @ StringType()) =>
323-
p.makeCopy(Array(Cast(left, StringType), right))
324-
case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) =>
325-
p.makeCopy(Array(left, Cast(right, StringType)))
326-
case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) =>
327-
p.makeCopy(Array(Cast(left, StringType), right))
328-
329-
// Comparisons between dates and timestamps.
330-
case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) =>
331-
p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
332-
case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) =>
333-
p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
334-
335-
// Checking NullType
336-
case p @ BinaryComparison(left @ StringType(), right @ NullType()) =>
337-
p.makeCopy(Array(left, Literal.create(null, StringType)))
338-
case p @ BinaryComparison(left @ NullType(), right @ StringType()) =>
339-
p.makeCopy(Array(Literal.create(null, StringType), right))
340-
341-
case p @ BinaryComparison(left @ StringType(), right) if right.dataType != StringType =>
342-
p.makeCopy(Array(Cast(left, right.dataType), right))
343-
case p @ BinaryComparison(left, right @ StringType()) if left.dataType != StringType =>
344-
p.makeCopy(Array(left, Cast(right, left.dataType)))
349+
case p @ BinaryComparison(left, right)
350+
if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined =>
351+
val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get
352+
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
345353

346354
case i @ In(a @ DateType(), b) if b.forall(_.dataType == StringType) =>
347355
i.makeCopy(Array(Cast(a, StringType), b))
@@ -356,6 +364,8 @@ object TypeCoercion {
356364
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
357365
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
358366
case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
367+
case UnaryMinus(e @ StringType()) => UnaryMinus(Cast(e, DoubleType))
368+
case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType))
359369
case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
360370
case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
361371
case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,13 @@ class TypeCoercionSuite extends PlanTest {
891891
test("binary comparison with string promotion") {
892892
ruleTest(PromoteStrings,
893893
GreaterThan(Literal("123"), Literal(1)),
894-
GreaterThan(Cast(Literal("123"), IntegerType), Literal(1)))
894+
GreaterThan(Cast(Literal("123"), DoubleType), Cast(Literal(1), DoubleType)))
895+
ruleTest(PromoteStrings,
896+
GreaterThan(Literal("123"), Literal(1L)),
897+
GreaterThan(Cast(Literal("123"), DoubleType), Cast(Literal(1L), DoubleType)))
898+
ruleTest(PromoteStrings,
899+
GreaterThan(Literal("123"), Literal(0.1)),
900+
GreaterThan(Cast(Literal("123"), DoubleType), Literal(0.1)))
895901
ruleTest(PromoteStrings,
896902
LessThan(Literal(true), Literal("123")),
897903
LessThan(Literal(true), Cast(Literal("123"), BooleanType)))

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2503,4 +2503,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
25032503
case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage)
25042504
}
25052505
}
2506+
2507+
test("SPARK-21774: should cast a string to double type when compare with a int") {
2508+
withTempView("src") {
2509+
Seq(("0", 1), ("-0.4", 2)).toDF("a", "b").createOrReplaceTempView("src")
2510+
checkAnswer(sql("SELECT a FROM src WHERE a=0"), Seq(Row("0")))
2511+
checkAnswer(sql("SELECT a FROM src WHERE a=0L"), Seq(Row("0")))
2512+
checkAnswer(sql("SELECT a FROM src WHERE a=0.0"), Seq(Row("0")))
2513+
checkAnswer(sql("SELECT a FROM src WHERE a=-0.4"), Seq(Row("-0.4")))
2514+
}
2515+
}
25062516
}

0 commit comments

Comments
 (0)