diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md
index b2bd8cefc3f96..a643a843a5cb4 100644
--- a/docs/sql-migration-guide-upgrade.md
+++ b/docs/sql-migration-guide-upgrade.md
@@ -161,6 +161,95 @@ license: |
- Since Spark 3.0, Dataset query fails if it contains ambiguous column reference that is caused by self join. A typical example: `val df1 = ...; val df2 = df1.filter(...);`, then `df1.join(df2, df1("a") > df2("a"))` returns an empty result which is quite confusing. This is because Spark cannot resolve Dataset column references that point to tables being self joined, and `df1("a")` is exactly the same as `df2("a")` in Spark. To restore the behavior before Spark 3.0, you can set `spark.sql.analyzer.failAmbiguousSelfJoin` to `false`.
+ - Since Spark 3.0, `Cast` function processes string literals such as 'Infinity', '+Infinity', '-Infinity', 'NaN', 'Inf', '+Inf', '-Inf' in case insensitive manner when casting the literals to `Double` or `Float` type to ensure greater compatibility with other database systems. This behaviour change is illustrated in the table below:
+
+
+ |
+ Operation
+ |
+
+ Result prior to Spark 3.0
+ |
+
+ Result starting Spark 3.0
+ |
+
+
+
+ CAST('infinity' AS DOUBLE)
+ CAST('+infinity' AS DOUBLE)
+ CAST('inf' AS DOUBLE)
+ CAST('+inf' AS DOUBLE)
+ |
+
+ NULL
+ |
+
+ Double.PositiveInfinity
+ |
+
+
+
+ CAST('-infinity' AS DOUBLE)
+ CAST('-inf' AS DOUBLE)
+ |
+
+ NULL
+ |
+
+ Double.NegativeInfinity
+ |
+
+
+
+ CAST('infinity' AS FLOAT)
+ CAST('+infinity' AS FLOAT)
+ CAST('inf' AS FLOAT)
+ CAST('+inf' AS FLOAT)
+ |
+
+ NULL
+ |
+
+ Float.PositiveInfinity
+ |
+
+
+
+ CAST('-infinity' AS FLOAT)
+ CAST('-inf' AS FLOAT)
+ |
+
+ NULL
+ |
+
+ Float.NegativeInfinity
+ |
+
+
+ |
+ CAST('nan' AS DOUBLE)
+ |
+
+ NULL
+ |
+
+ Double.NaN
+ |
+
+
+ |
+ CAST('nan' AS FLOAT)
+ |
+
+ NULL
+ |
+
+ Float.NaN
+ |
+
+
+
## Upgrading from Spark SQL 2.4 to 2.4.1
- The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 10464dac8d55e..7ba0910ac2157 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.math.{BigDecimal => JavaBigDecimal}
import java.time.ZoneId
+import java.util.Locale
import java.util.concurrent.TimeUnit._
import org.apache.spark.SparkException
@@ -192,6 +193,22 @@ object Cast {
}
def resolvableNullability(from: Boolean, to: Boolean): Boolean = !from || to
+
+ /**
+ * We process literals such as 'Infinity', 'Inf', '-Infinity' and 'NaN' etc in case
+ * insensitive manner to be compatible with other database systems such as PostgreSQL and DB2.
+ */
+ def processFloatingPointSpecialLiterals(v: String, isFloat: Boolean): Any = {
+ v.trim.toLowerCase(Locale.ROOT) match {
+ case "inf" | "+inf" | "infinity" | "+infinity" =>
+ if (isFloat) Float.PositiveInfinity else Double.PositiveInfinity
+ case "-inf" | "-infinity" =>
+ if (isFloat) Float.NegativeInfinity else Double.NegativeInfinity
+ case "nan" =>
+ if (isFloat) Float.NaN else Double.NaN
+ case _ => null
+ }
+ }
}
/**
@@ -562,8 +579,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toString.toDouble catch {
- case _: NumberFormatException => null
+ buildCast[UTF8String](_, s => {
+ val doubleStr = s.toString
+ try doubleStr.toDouble catch {
+ case _: NumberFormatException =>
+ Cast.processFloatingPointSpecialLiterals(doubleStr, false)
+ }
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1d else 0d)
@@ -578,8 +599,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// FloatConverter
private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toString.toFloat catch {
- case _: NumberFormatException => null
+ buildCast[UTF8String](_, s => {
+ val floatStr = s.toString
+ try floatStr.toFloat catch {
+ case _: NumberFormatException =>
+ Cast.processFloatingPointSpecialLiterals(floatStr, true)
+ }
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1f else 0f)
@@ -717,9 +742,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case ByteType => castToByteCode(from, ctx)
case ShortType => castToShortCode(from, ctx)
case IntegerType => castToIntCode(from, ctx)
- case FloatType => castToFloatCode(from)
+ case FloatType => castToFloatCode(from, ctx)
case LongType => castToLongCode(from, ctx)
- case DoubleType => castToDoubleCode(from)
+ case DoubleType => castToDoubleCode(from, ctx)
case array: ArrayType =>
castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
@@ -1259,48 +1284,66 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => code"$evPrim = (long) $c;"
}
- private[this] def castToFloatCode(from: DataType): CastFunction = from match {
- case StringType =>
- (c, evPrim, evNull) =>
- code"""
+ private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = {
+ from match {
+ case StringType =>
+ val floatStr = ctx.freshVariable("floatStr", StringType)
+ (c, evPrim, evNull) =>
+ code"""
+ final String $floatStr = $c.toString();
try {
- $evPrim = Float.valueOf($c.toString());
+ $evPrim = Float.valueOf($floatStr);
} catch (java.lang.NumberFormatException e) {
- $evNull = true;
+ final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
+ if (f == null) {
+ $evNull = true;
+ } else {
+ $evPrim = f.floatValue();
+ }
}
"""
- case BooleanType =>
- (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;"
- case DateType =>
- (c, evPrim, evNull) => code"$evNull = true;"
- case TimestampType =>
- (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});"
- case DecimalType() =>
- (c, evPrim, evNull) => code"$evPrim = $c.toFloat();"
- case x: NumericType =>
- (c, evPrim, evNull) => code"$evPrim = (float) $c;"
+ case BooleanType =>
+ (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;"
+ case DateType =>
+ (c, evPrim, evNull) => code"$evNull = true;"
+ case TimestampType =>
+ (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});"
+ case DecimalType() =>
+ (c, evPrim, evNull) => code"$evPrim = $c.toFloat();"
+ case x: NumericType =>
+ (c, evPrim, evNull) => code"$evPrim = (float) $c;"
+ }
}
- private[this] def castToDoubleCode(from: DataType): CastFunction = from match {
- case StringType =>
- (c, evPrim, evNull) =>
- code"""
+ private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = {
+ from match {
+ case StringType =>
+ val doubleStr = ctx.freshVariable("doubleStr", StringType)
+ (c, evPrim, evNull) =>
+ code"""
+ final String $doubleStr = $c.toString();
try {
- $evPrim = Double.valueOf($c.toString());
+ $evPrim = Double.valueOf($doubleStr);
} catch (java.lang.NumberFormatException e) {
- $evNull = true;
+ final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
+ if (d == null) {
+ $evNull = true;
+ } else {
+ $evPrim = d.doubleValue();
+ }
}
"""
- case BooleanType =>
- (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;"
- case DateType =>
- (c, evPrim, evNull) => code"$evNull = true;"
- case TimestampType =>
- (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};"
- case DecimalType() =>
- (c, evPrim, evNull) => code"$evPrim = $c.toDouble();"
- case x: NumericType =>
- (c, evPrim, evNull) => code"$evPrim = (double) $c;"
+ case BooleanType =>
+ (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;"
+ case DateType =>
+ (c, evPrim, evNull) => code"$evNull = true;"
+ case TimestampType =>
+ (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};"
+ case DecimalType() =>
+ (c, evPrim, evNull) => code"$evPrim = $c.toDouble();"
+ case x: NumericType =>
+ (c, evPrim, evNull) => code"$evPrim = (double) $c;"
+ }
}
private[this] def castArrayCode(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index bbb3cb516b7d5..861bfc92bbe66 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -1045,4 +1045,30 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
Cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented")
}
}
+
+ test("Process Infinity, -Infinity, NaN in case insensitive manner") {
+ Seq("inf", "+inf", "infinity", "+infiNity", " infinity ").foreach { value =>
+ checkEvaluation(cast(value, FloatType), Float.PositiveInfinity)
+ }
+ Seq("-infinity", "-infiniTy", " -infinity ", " -inf ").foreach { value =>
+ checkEvaluation(cast(value, FloatType), Float.NegativeInfinity)
+ }
+ Seq("inf", "+inf", "infinity", "+infiNity", " infinity ").foreach { value =>
+ checkEvaluation(cast(value, DoubleType), Double.PositiveInfinity)
+ }
+ Seq("-infinity", "-infiniTy", " -infinity ", " -inf ").foreach { value =>
+ checkEvaluation(cast(value, DoubleType), Double.NegativeInfinity)
+ }
+ Seq("nan", "nAn", " nan ").foreach { value =>
+ checkEvaluation(cast(value, FloatType), Float.NaN)
+ }
+ Seq("nan", "nAn", " nan ").foreach { value =>
+ checkEvaluation(cast(value, DoubleType), Double.NaN)
+ }
+
+ // Invalid literals when casted to double and float results in null.
+ Seq(DoubleType, FloatType).foreach { dataType =>
+ checkEvaluation(cast("badvalue", dataType), null)
+ }
+ }
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/aggregates_part1.sql
index 801a16cf41f54..5d54be9341148 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/aggregates_part1.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/aggregates_part1.sql
@@ -59,16 +59,14 @@ select avg(CAST(null AS DOUBLE)) from range(1,4);
select sum(CAST('NaN' AS DOUBLE)) from range(1,4);
select avg(CAST('NaN' AS DOUBLE)) from range(1,4);
--- [SPARK-27768] verify correct results for infinite inputs
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
-FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('Infinity' AS DOUBLE))) v(x);
+FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('infinity' AS DOUBLE))) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
-FROM (VALUES ('Infinity'), ('1')) v(x);
+FROM (VALUES ('infinity'), ('1')) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
-FROM (VALUES ('Infinity'), ('Infinity')) v(x);
+FROM (VALUES ('infinity'), ('infinity')) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
-FROM (VALUES ('-Infinity'), ('Infinity')) v(x);
-
+FROM (VALUES ('-infinity'), ('infinity')) v(x);
-- test accuracy with a large input offset
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float4.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float4.sql
index 3dad5cd56ba02..058467695a608 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float4.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float4.sql
@@ -38,7 +38,6 @@ INSERT INTO FLOAT4_TBL VALUES ('1.2345678901234e-20');
-- special inputs
SELECT float('NaN');
--- [SPARK-28060] Float type can not accept some special inputs
SELECT float('nan');
SELECT float(' NAN ');
SELECT float('infinity');
@@ -49,7 +48,6 @@ SELECT float('N A N');
SELECT float('NaN x');
SELECT float(' INFINITY x');
--- [SPARK-28060] Float type can not accept some special inputs
SELECT float('Infinity') + 100.0;
SELECT float('Infinity') / float('Infinity');
SELECT float('nan') / float('nan');
diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float8.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float8.sql
index 6f8e3b596e60e..957dabdebab4e 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float8.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float8.sql
@@ -37,7 +37,6 @@ SELECT double('-10e-400');
-- special inputs
SELECT double('NaN');
--- [SPARK-28060] Double type can not accept some special inputs
SELECT double('nan');
SELECT double(' NAN ');
SELECT double('infinity');
@@ -49,7 +48,6 @@ SELECT double('NaN x');
SELECT double(' INFINITY x');
SELECT double('Infinity') + 100.0;
--- [SPARK-27768] Infinity, -Infinity, NaN should be recognized in a case insensitive manner
SELECT double('Infinity') / double('Infinity');
SELECT double('NaN') / double('NaN');
-- [SPARK-28315] Decimal can not accept NaN as input
@@ -190,7 +188,7 @@ SELECT tanh(double('1'));
SELECT asinh(double('1'));
SELECT acosh(double('2'));
SELECT atanh(double('0.5'));
--- [SPARK-27768] Infinity, -Infinity, NaN should be recognized in a case insensitive manner
+
-- test Inf/NaN cases for hyperbolic functions
SELECT sinh(double('Infinity'));
SELECT sinh(double('-Infinity'));
diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/aggregates_part1.sql.out
index 51ca1d558691c..29bafb42f579e 100644
--- a/sql/core/src/test/resources/sql-tests/results/pgSQL/aggregates_part1.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/aggregates_part1.sql.out
@@ -236,7 +236,7 @@ NaN
-- !query 29
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
-FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('Infinity' AS DOUBLE))) v(x)
+FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('infinity' AS DOUBLE))) v(x)
-- !query 29 schema
struct
-- !query 29 output
@@ -245,7 +245,7 @@ Infinity NaN
-- !query 30
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
-FROM (VALUES ('Infinity'), ('1')) v(x)
+FROM (VALUES ('infinity'), ('1')) v(x)
-- !query 30 schema
struct
-- !query 30 output
@@ -254,7 +254,7 @@ Infinity NaN
-- !query 31
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
-FROM (VALUES ('Infinity'), ('Infinity')) v(x)
+FROM (VALUES ('infinity'), ('infinity')) v(x)
-- !query 31 schema
struct
-- !query 31 output
@@ -263,7 +263,7 @@ Infinity NaN
-- !query 32
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
-FROM (VALUES ('-Infinity'), ('Infinity')) v(x)
+FROM (VALUES ('-infinity'), ('infinity')) v(x)
-- !query 32 schema
struct
-- !query 32 output
diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/float4.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/float4.sql.out
index 86d88007d8892..6e47cff91a7d5 100644
--- a/sql/core/src/test/resources/sql-tests/results/pgSQL/float4.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/float4.sql.out
@@ -63,7 +63,7 @@ SELECT float('nan')
-- !query 7 schema
struct
-- !query 7 output
-NULL
+NaN
-- !query 8
@@ -71,7 +71,7 @@ SELECT float(' NAN ')
-- !query 8 schema
struct
-- !query 8 output
-NULL
+NaN
-- !query 9
@@ -79,7 +79,7 @@ SELECT float('infinity')
-- !query 9 schema
struct
-- !query 9 output
-NULL
+Infinity
-- !query 10
@@ -87,7 +87,7 @@ SELECT float(' -INFINiTY ')
-- !query 10 schema
struct
-- !query 10 output
-NULL
+-Infinity
-- !query 11
@@ -135,7 +135,7 @@ SELECT float('nan') / float('nan')
-- !query 16 schema
struct<(CAST(CAST(nan AS FLOAT) AS DOUBLE) / CAST(CAST(nan AS FLOAT) AS DOUBLE)):double>
-- !query 16 output
-NULL
+NaN
-- !query 17
diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/float8.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/float8.sql.out
index eb9e8aa6361a1..b4ea3c1ad1cab 100644
--- a/sql/core/src/test/resources/sql-tests/results/pgSQL/float8.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/float8.sql.out
@@ -95,7 +95,7 @@ SELECT double('nan')
-- !query 11 schema
struct
-- !query 11 output
-NULL
+NaN
-- !query 12
@@ -103,7 +103,7 @@ SELECT double(' NAN ')
-- !query 12 schema
struct
-- !query 12 output
-NULL
+NaN
-- !query 13
@@ -111,7 +111,7 @@ SELECT double('infinity')
-- !query 13 schema
struct
-- !query 13 output
-NULL
+Infinity
-- !query 14
@@ -119,7 +119,7 @@ SELECT double(' -INFINiTY ')
-- !query 14 schema
struct
-- !query 14 output
-NULL
+-Infinity
-- !query 15