-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-30292][SQL] Throw Exception when invalid string is cast to numeric type in ANSI mode #26933
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
3ed7795
1686c6a
69ee231
74809d0
a336084
c0f8baf
f46181d
d3ffa3c
c7dbeef
7d0faa6
d454452
4b0149c
40afc54
2f845c3
0cb4edc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -482,6 +482,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| // LongConverter | ||
| private[this] def castToLong(from: DataType): Any => Any = from match { | ||
| case StringType if ansiEnabled => | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| buildCast[UTF8String](_, _.toLongExact()) | ||
| case StringType => | ||
| val result = new LongWrapper() | ||
| buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) | ||
|
|
@@ -499,6 +501,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| // IntConverter | ||
| private[this] def castToInt(from: DataType): Any => Any = from match { | ||
| case StringType if ansiEnabled => | ||
| buildCast[UTF8String](_, _.toIntExact()) | ||
| case StringType => | ||
| val result = new IntWrapper() | ||
| buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) | ||
|
|
@@ -518,6 +522,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| // ShortConverter | ||
| private[this] def castToShort(from: DataType): Any => Any = from match { | ||
| case StringType if ansiEnabled => | ||
| buildCast[UTF8String](_, _.toShortExact()) | ||
| case StringType => | ||
| val result = new IntWrapper() | ||
| buildCast[UTF8String](_, s => if (s.toShort(result)) { | ||
|
|
@@ -559,6 +565,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| // ByteConverter | ||
| private[this] def castToByte(from: DataType): Any => Any = from match { | ||
| case StringType if ansiEnabled => | ||
| buildCast[UTF8String](_, _.toByteExact()) | ||
| case StringType => | ||
| val result = new IntWrapper() | ||
| buildCast[UTF8String](_, s => if (s.toByte(result)) { | ||
|
|
@@ -636,7 +644,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| // Please refer to https://github.com/apache/spark/pull/26640 | ||
| changePrecision(Decimal(new JavaBigDecimal(s.toString.trim)), target) | ||
| } catch { | ||
| case _: NumberFormatException => null | ||
| case _: NumberFormatException => | ||
| if (ansiEnabled) { | ||
| throw new NumberFormatException(s"invalid input syntax for type numeric: $s") | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } else { | ||
| null | ||
| } | ||
| }) | ||
| case BooleanType => | ||
| buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) | ||
|
|
@@ -659,6 +672,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| // DoubleConverter | ||
| private[this] def castToDouble(from: DataType): Any => Any = from match { | ||
| case StringType if ansiEnabled => | ||
| buildCast[UTF8String](_, s => { | ||
| val doubleStr = s.toString | ||
| try doubleStr.toDouble catch { | ||
| case _: NumberFormatException => | ||
| val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false) | ||
| if(d == null) { | ||
| throw new NumberFormatException(s"invalid input syntax for type numeric: $s") | ||
| } else { | ||
| d.asInstanceOf[Double].doubleValue() | ||
| } | ||
| } | ||
| }) | ||
| case StringType => | ||
| buildCast[UTF8String](_, s => { | ||
| val doubleStr = s.toString | ||
|
|
@@ -679,6 +705,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| // FloatConverter | ||
| private[this] def castToFloat(from: DataType): Any => Any = from match { | ||
| case StringType if ansiEnabled => | ||
| buildCast[UTF8String](_, s => { | ||
| val floatStr = s.toString | ||
| try floatStr.toFloat catch { | ||
| case _: NumberFormatException => | ||
| val f = Cast.processFloatingPointSpecialLiterals(floatStr, true) | ||
| if (f == null) { | ||
|
||
| throw new NumberFormatException(s"invalid input syntax for type numeric: $s") | ||
| } else { | ||
| f.asInstanceOf[Float].floatValue() | ||
| } | ||
| } | ||
| }) | ||
| case StringType => | ||
| buildCast[UTF8String](_, s => { | ||
| val floatStr = s.toString | ||
|
|
@@ -1133,7 +1172,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim())); | ||
| ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} | ||
| } catch (java.lang.NumberFormatException e) { | ||
| $evNull = true; | ||
| if ($ansiEnabled) { | ||
cloud-fan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| throw new NumberFormatException("invalid input syntax for type numeric: $c"); | ||
| } else { | ||
| $evNull =true; | ||
| } | ||
| } | ||
| """ | ||
| case BooleanType => | ||
|
|
@@ -1358,13 +1401,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| case StringType => | ||
| val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) | ||
| (c, evPrim, evNull) => | ||
| val casting = if (ansiEnabled) { | ||
| s"$evPrim = $c.toByteExact();" | ||
| } else { | ||
| s""" | ||
| if ($c.toByte($wrapper)) { | ||
| $evPrim = (byte) $wrapper.value; | ||
| } else { | ||
| $evNull = true; | ||
| } | ||
| """ | ||
| } | ||
| code""" | ||
| UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, we don't need to create the int wrapper at all for ANSI mode.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
| if ($c.toByte($wrapper)) { | ||
| $evPrim = (byte) $wrapper.value; | ||
| } else { | ||
| $evNull = true; | ||
| } | ||
| $casting | ||
| $wrapper = null; | ||
| """ | ||
| case BooleanType => | ||
|
|
@@ -1389,13 +1439,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| case StringType => | ||
| val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) | ||
| (c, evPrim, evNull) => | ||
| val casting = if (ansiEnabled) { | ||
| s"$evPrim = $c.toShortExact();" | ||
| } else { | ||
| s""" | ||
| if ($c.toShort($wrapper)) { | ||
| $evPrim = (short) $wrapper.value; | ||
| } else { | ||
| $evNull = true; | ||
| } | ||
| """ | ||
| } | ||
| code""" | ||
| UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); | ||
| if ($c.toShort($wrapper)) { | ||
| $evPrim = (short) $wrapper.value; | ||
| } else { | ||
| $evNull = true; | ||
| } | ||
| $casting | ||
| $wrapper = null; | ||
| """ | ||
| case BooleanType => | ||
|
|
@@ -1418,13 +1475,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| case StringType => | ||
| val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) | ||
| (c, evPrim, evNull) => | ||
| val casting = if (ansiEnabled) { | ||
| s"$evPrim = $c.toIntExact();" | ||
| } else { | ||
| s""" | ||
| if ($c.toInt($wrapper)) { | ||
| $evPrim = $wrapper.value; | ||
| } else { | ||
| $evNull = true; | ||
| } | ||
| """ | ||
| } | ||
| code""" | ||
| UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); | ||
| if ($c.toInt($wrapper)) { | ||
| $evPrim = $wrapper.value; | ||
| } else { | ||
| $evNull = true; | ||
| } | ||
| $casting | ||
| $wrapper = null; | ||
| """ | ||
| case BooleanType => | ||
|
|
@@ -1445,15 +1509,21 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
| private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { | ||
| case StringType => | ||
| val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) | ||
|
|
||
| (c, evPrim, evNull) => | ||
| val casting = if (ansiEnabled) { | ||
| s"$evPrim = $c.toLongExact();" | ||
| } else { | ||
| s""" | ||
| if ($c.toLong($wrapper)) { | ||
| $evPrim = $wrapper.value; | ||
| } else { | ||
| $evNull = true; | ||
| } | ||
| """ | ||
| } | ||
| code""" | ||
| UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); | ||
| if ($c.toLong($wrapper)) { | ||
| $evPrim = $wrapper.value; | ||
| } else { | ||
| $evNull = true; | ||
| } | ||
| $casting | ||
| $wrapper = null; | ||
| """ | ||
| case BooleanType => | ||
|
|
@@ -1473,6 +1543,22 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = { | ||
| from match { | ||
| case StringType if ansiEnabled => | ||
| val floatStr = ctx.freshVariable("floatStr", StringType) | ||
| (c, evPrim, evNull) => | ||
| code""" | ||
| final String $floatStr = $c.toString(); | ||
| try { | ||
| $evPrim = Float.valueOf($floatStr); | ||
| } catch (java.lang.NumberFormatException e) { | ||
| final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true); | ||
| if (f == null) { | ||
|
||
| throw new NumberFormatException("invalid input syntax for type numeric: $c"); | ||
| } else { | ||
| $evPrim = f.floatValue(); | ||
| } | ||
| } | ||
| """ | ||
| case StringType => | ||
| val floatStr = ctx.freshVariable("floatStr", StringType) | ||
| (c, evPrim, evNull) => | ||
|
|
@@ -1504,6 +1590,22 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit | |
|
|
||
| private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = { | ||
| from match { | ||
| case StringType if ansiEnabled => | ||
| val doubleStr = ctx.freshVariable("doubleStr", StringType) | ||
| (c, evPrim, evNull) => | ||
| code""" | ||
| final String $doubleStr = $c.toString(); | ||
| try { | ||
| $evPrim = Double.valueOf($doubleStr); | ||
| } catch (java.lang.NumberFormatException e) { | ||
| final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false); | ||
| if (d == null) { | ||
| throw new NumberFormatException("invalid input syntax for type numeric: $c"); | ||
| } else { | ||
| $evPrim = d.doubleValue(); | ||
| } | ||
| } | ||
| """ | ||
| case StringType => | ||
| val doubleStr = ctx.freshVariable("doubleStr", StringType) | ||
| (c, evPrim, evNull) => | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.