diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index b062a0436e43..1edc153f4a13 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -23,6 +23,8 @@ license: | {:toc} ## Upgrading From Spark SQL 2.4 to 3.0 + - Since Spark 3.0, when casting from string to boolean, date, timestamp or numeric types, leading and trailing whitespace is trimmed from the value before it is converted; for example, `boolean(' f ')` now returns `false` instead of `NULL`. + - Since Spark 3.0, PySpark requires a Pandas version of 0.23.2 or higher to use Pandas related functionality, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - Since Spark 3.0, PySpark requires a PyArrow version of 0.12.1 or higher to use PyArrow related functionality, such as `pandas_udf`, `toPandas` and `createDataFrame` with "spark.sql.execution.arrow.enabled=true", etc. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f8c1102953ab..682f6514f20b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -433,7 +433,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => val result = new LongWrapper() - buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) + buildCast[UTF8String](_, s => if (s.trim.toLong(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) case DateType => @@ -448,7 +448,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else 
null) + buildCast[UTF8String](_, s => if (s.trim.toInt(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => @@ -463,7 +463,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toShort(result)) { + buildCast[UTF8String](_, s => if (s.trim.toShort(result)) { result.value.toShort } else { null @@ -482,7 +482,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toByte(result)) { + buildCast[UTF8String](_, s => if (s.trim.toByte(result)) { result.value.toByte } else { null @@ -518,7 +518,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try { - changePrecision(Decimal(new JavaBigDecimal(s.toString)), target) + changePrecision(Decimal(new JavaBigDecimal(s.toString.trim)), target) } catch { case _: NumberFormatException => null }) @@ -544,7 +544,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // DoubleConverter private[this] def castToDouble(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toString.toDouble catch { + buildCast[UTF8String](_, s => try s.toString.trim.toDouble catch { case _: NumberFormatException => null }) case BooleanType => @@ -560,7 +560,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // FloatConverter private[this] def castToFloat(from: DataType): Any => Any = from match { case StringType => - 
buildCast[UTF8String](_, s => try s.toString.toFloat catch { + buildCast[UTF8String](_, s => try s.toString.trim.toFloat catch { case _: NumberFormatException => null }) case BooleanType => @@ -983,7 +983,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => code""" try { - Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); + Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim())); ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} } catch (java.lang.NumberFormatException e) { $evNull = true; @@ -1136,7 +1136,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toByte($wrapper)) { + if ($c.trim().toByte($wrapper)) { $evPrim = (byte) $wrapper.value; } else { $evNull = true; @@ -1163,7 +1163,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toShort($wrapper)) { + if ($c.trim().toShort($wrapper)) { $evPrim = (short) $wrapper.value; } else { $evNull = true; @@ -1188,7 +1188,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toInt($wrapper)) { + if ($c.trim().toInt($wrapper)) { $evPrim = $wrapper.value; } else { $evNull = true; @@ -1214,7 +1214,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => code""" UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); - if ($c.toLong($wrapper)) { + if ($c.trim().toLong($wrapper)) { $evPrim = $wrapper.value; } else { $evNull = true; @@ -1238,7 +1238,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => 
code""" try { - $evPrim = Float.valueOf($c.toString()); + $evPrim = Float.valueOf($c.toString().trim()); } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -1260,7 +1260,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => code""" try { - $evPrim = Double.valueOf($c.toString()); + $evPrim = Double.valueOf($c.toString().trim()); } catch (java.lang.NumberFormatException e) { $evNull = true; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 6510bacf5589..9bdae538a011 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -69,8 +69,8 @@ object StringUtils extends Logging { private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString) // scalastyle:off caselocale - def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) - def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) + def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.trim.toLowerCase) + def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.trim.toLowerCase) // scalastyle:on caselocale /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 4d667fd61ae0..8b88e7dd021c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.math.BigDecimal import java.sql.{Date, Timestamp} import java.util.{Calendar, TimeZone} import 
java.util.concurrent.TimeUnit._ @@ -1018,4 +1019,27 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ret, InternalRow(null)) } } + + test("Trim the string when cast string type to Boolean/Numeric types") { + Seq(" true ", " true", "true ").foreach { str => + checkEvaluation(Cast(Literal(str), BooleanType), true) + } + Seq(" false ", " false", "false ").foreach { str => + checkEvaluation(Cast(Literal(str), BooleanType), false) + } + + Seq(" 1 ", " 1", "1 ").foreach { str => + checkEvaluation(Cast(Literal(str), ByteType), 1.toByte) + checkEvaluation(Cast(Literal(str), ShortType), 1.toShort) + checkEvaluation(Cast(Literal(str), IntegerType), 1) + checkEvaluation(Cast(Literal(str), LongType), 1L) + checkEvaluation(Cast(Literal(str), DecimalType.IntDecimal), BigDecimal.ONE) + } + + Seq(" 1.23 ", " 1.23", "1.23 ").foreach { str => + checkEvaluation(Cast(Literal(str), FloatType), 1.23F) + checkEvaluation(Cast(Literal(str), DoubleType), 1.23D) + checkEvaluation(Cast(Literal(str), DecimalType.FloatDecimal), BigDecimal.valueOf(1.2300000)) + } + } } diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/boolean.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/boolean.sql.out index 99c42ec2eb6c..3e5b88a3006e 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/boolean.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/boolean.sql.out @@ -39,7 +39,7 @@ SELECT boolean(' f ') AS false -- !query 4 schema struct -- !query 4 output -NULL +false -- !query 5 @@ -296,7 +296,7 @@ SELECT boolean(string(' true ')) AS true, -- !query 36 schema struct -- !query 36 output -NULL NULL +true false -- !query 37