diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8d11f4663a3e..88a2de414634 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -484,7 +484,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => val result = new LongWrapper() - buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) + buildCast[UTF8String](_, s => if (s.trim.toLong(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) case DateType => @@ -501,7 +501,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) + buildCast[UTF8String](_, s => if (s.trim.toInt(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => @@ -520,7 +520,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toShort(result)) { + buildCast[UTF8String](_, s => if (s.trim.toShort(result)) { result.value.toShort } else { null @@ -561,7 +561,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toByte(result)) { + buildCast[UTF8String](_, s => if (s.trim.toByte(result)) { result.value.toByte } else { null @@ -632,7 +632,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try { - changePrecision(Decimal(new JavaBigDecimal(s.toString)), target) + changePrecision(Decimal(new JavaBigDecimal(s.trim.toString)), target) } catch { case _: NumberFormatException => null }) @@ -659,7 +659,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToDouble(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => { - val doubleStr = s.toString + val doubleStr = s.trim.toString try doubleStr.toDouble catch { case _: NumberFormatException => Cast.processFloatingPointSpecialLiterals(doubleStr, false) @@ -679,7 +679,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToFloat(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => { - val floatStr = s.toString + val floatStr = s.trim.toString try floatStr.toFloat catch { case _: NumberFormatException => Cast.processFloatingPointSpecialLiterals(floatStr, true) diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql index 8a035f594be5..049267e83ed1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cast.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql @@ -60,3 +60,12 @@ DESC FUNCTION EXTENDED boolean; -- cast string to interval and interval to string SELECT CAST('interval 3 month 1 hour' AS interval); SELECT CAST(interval 3 month 1 hour AS string); + +select cast(' 1' as tinyint); +select cast(' 1' as smallint); +select cast(' 1' as INT); +select cast(' 1' as bigint); +select cast(' 1' as float); +select cast(' 1 ' as DOUBLE); +select cast('1.0 ' as DEC); + diff --git a/sql/core/src/test/resources/sql-tests/inputs/comparator.sql b/sql/core/src/test/resources/sql-tests/inputs/comparator.sql index 3e2447723e57..90dfd77c1fcd 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/comparator.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/comparator.sql @@ -1,3 +1,11 @@ -- binary type select x'00' < x'0f'; select x'00' < x'ff'; + +select '1 ' = 1Y; +select '1 ' = 1S; +select '1 ' = 1; +select ' 1' = 1L; +select ' 1' = cast(1.0 as float); +select ' 1.0 ' = 1.0D; +select ' 1.0 ' = 1.0BD; diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out index 609d283da555..05b3bec80b0d 100644 --- a/sql/core/src/test/resources/sql-tests/results/cast.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 35 +-- Number of queries: 42 -- !query 0 @@ -287,3 +287,59 @@ SELECT CAST(interval 3 month 1 hour AS string) struct -- !query 34 output 3 months 1 hours + + +-- !query 35 +select cast(' 1' as tinyint) +-- !query 35 schema +struct +-- !query 35 output +1 + + +-- !query 36 +select cast(' 1' as smallint) +-- !query 36 schema +struct +-- !query 36 output +1 + + +-- !query 37 +select cast(' 1' as INT) +-- !query 37 schema +struct +-- !query 37 output +1 + + +-- !query 38 +select cast(' 1' as bigint) +-- !query 38 schema +struct +-- !query 38 output +1 + + +-- !query 39 +select cast(' 1' as float) +-- !query 39 schema +struct +-- !query 39 output +1.0 + + +-- !query 40 +select cast(' 1 ' as DOUBLE) +-- !query 40 schema +struct +-- !query 40 output +1.0 + + +-- !query 41 +select cast('1.0 ' as DEC) +-- !query 41 schema +struct +-- !query 41 output +1 diff --git a/sql/core/src/test/resources/sql-tests/results/comparator.sql.out b/sql/core/src/test/resources/sql-tests/results/comparator.sql.out index afc7b5448b7b..f56c0dfacf5a 100644 --- a/sql/core/src/test/resources/sql-tests/results/comparator.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/comparator.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 2 +-- Number of queries: 9 -- !query 0 @@ -16,3 +16,59 @@ select x'00' < x'ff' struct<(X'00' < X'FF'):boolean> -- !query 1 output true + + +-- !query 2 +select '1 ' = 1Y +-- !query 2 schema +struct<(CAST(1 AS TINYINT) = 1):boolean> +-- !query 2 output +true + + +-- !query 3 +select '1 ' = 1S +-- !query 3 schema +struct<(CAST(1 AS SMALLINT) = 1):boolean> +-- !query 3 output +true + + +-- !query 4 +select '1 ' = 1 +-- !query 4 schema +struct<(CAST(1 AS INT) = 1):boolean> +-- !query 4 output +true + + +-- !query 5 +select ' 1' = 1L +-- !query 5 schema +struct<(CAST( 1 AS BIGINT) = 1):boolean> +-- !query 5 output +true + + +-- !query 6 +select ' 1' = cast(1.0 as float) +-- !query 6 schema +struct<(CAST( 1 AS FLOAT) = CAST(1.0 AS FLOAT)):boolean> +-- !query 6 output +true + + +-- !query 7 +select ' 1.0 ' = 1.0D +-- !query 7 schema +struct<(CAST( 1.0 AS DOUBLE) = 1.0):boolean> +-- !query 7 output +true + + +-- !query 8 +select ' 1.0 ' = 1.0BD +-- !query 8 schema +struct<(CAST( 1.0 AS DOUBLE) = CAST(1.0 AS DOUBLE)):boolean> +-- !query 8 output +true