Skip to content
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,52 @@ public boolean toByte(IntWrapper intWrapper) {
return false;
}

/**
* Parses UTF8String(trimmed if needed) to long. This method is used when ANSI is enabled.
*
* @return If string contains valid numeric value then it returns the long value otherwise a
* NumberFormatException is thrown.
*/
public long toLongExact() {
LongWrapper result = new LongWrapper();
if (toLong(result)) {
return result.value;
}
throw new NumberFormatException("invalid input syntax for type numeric: " + this);
}

/**
* Parses UTF8String(trimmed if needed) to int. This method is used when ANSI is enabled.
*
* @return If string contains valid numeric value then it returns the int value otherwise a
* NumberFormatException is thrown.
*/
public int toIntExact() {
IntWrapper result = new IntWrapper();
if (toInt(result)) {
return result.value;
}
throw new NumberFormatException("invalid input syntax for type numeric: " + this);
}

public short toShortExact() {
int value = this.toIntExact();
short result = (short) value;
if (result == value) {
return result;
}
throw new NumberFormatException("invalid input syntax for type numeric: " + this);
}

public byte toByteExact() {
int value = this.toIntExact();
byte result = (byte) value;
if (result == value) {
return result;
}
throw new NumberFormatException("invalid input syntax for type numeric: " + this);
}

@Override
public String toString() {
return new String(getBytes(), StandardCharsets.UTF_8);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
buildCast[UTF8String](_, _.toLongExact())
case StringType =>
val result = new LongWrapper()
buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
Expand All @@ -499,6 +501,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
buildCast[UTF8String](_, _.toIntExact())
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
Expand All @@ -518,6 +522,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
buildCast[UTF8String](_, _.toShortExact())
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toShort(result)) {
Expand Down Expand Up @@ -559,6 +565,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
buildCast[UTF8String](_, _.toByteExact())
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toByte(result)) {
Expand Down Expand Up @@ -636,7 +644,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
// Please refer to https://github.com/apache/spark/pull/26640
changePrecision(Decimal(new JavaBigDecimal(s.toString.trim)), target)
} catch {
case _: NumberFormatException => null
case _: NumberFormatException =>
if (ansiEnabled) {
throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
} else {
null
}
})
case BooleanType =>
buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
Expand All @@ -659,6 +672,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
buildCast[UTF8String](_, s => {
val doubleStr = s.toString
try doubleStr.toDouble catch {
case _: NumberFormatException =>
val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
if(d == null) {
throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
} else {
d.asInstanceOf[Double].doubleValue()
}
}
})
case StringType =>
buildCast[UTF8String](_, s => {
val doubleStr = s.toString
Expand All @@ -679,6 +705,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// FloatConverter
private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
buildCast[UTF8String](_, s => {
val floatStr = s.toString
try floatStr.toFloat catch {
case _: NumberFormatException =>
val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
if (f == null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is too much code duplication. How about unifying these 2 cases?

val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
if (f == null && ansiEnabled) {
  throw ...
} else {
  f
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
} else {
f.asInstanceOf[Float].floatValue()
}
}
})
case StringType =>
buildCast[UTF8String](_, s => {
val floatStr = s.toString
Expand Down Expand Up @@ -1133,7 +1172,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim()));
${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)}
} catch (java.lang.NumberFormatException e) {
$evNull = true;
if ($ansiEnabled) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this will generate java code with if-else, we can do better

val handleException = if (ansiEnabled) {
  s"throw new NumberFormatException("invalid input syntax for type numeric: $c");"
} else {
  s"$evNull =true;"
}
code"""
  ...
  } catch (java.lang.NumberFormatException e) {
    $handleException
  }
"""

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

throw new NumberFormatException("invalid input syntax for type numeric: $c");
} else {
$evNull =true;
}
}
"""
case BooleanType =>
Expand Down Expand Up @@ -1358,13 +1401,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case StringType =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
(c, evPrim, evNull) =>
val casting = if (ansiEnabled) {
s"$evPrim = $c.toByteExact();"
} else {
s"""
if ($c.toByte($wrapper)) {
$evPrim = (byte) $wrapper.value;
} else {
$evNull = true;
}
"""
}
code"""
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah we don't need to create int wrapper at all for ansi mode

case StringType if ansi => (c, evPrim, evNull) => s"$evPrim = $c.toByteExact();"
case StringType => // the original code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if ($c.toByte($wrapper)) {
$evPrim = (byte) $wrapper.value;
} else {
$evNull = true;
}
$casting
$wrapper = null;
"""
case BooleanType =>
Expand All @@ -1389,13 +1439,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case StringType =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
(c, evPrim, evNull) =>
val casting = if (ansiEnabled) {
s"$evPrim = $c.toShortExact();"
} else {
s"""
if ($c.toShort($wrapper)) {
$evPrim = (short) $wrapper.value;
} else {
$evNull = true;
}
"""
}
code"""
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toShort($wrapper)) {
$evPrim = (short) $wrapper.value;
} else {
$evNull = true;
}
$casting
$wrapper = null;
"""
case BooleanType =>
Expand All @@ -1418,13 +1475,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case StringType =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
(c, evPrim, evNull) =>
val casting = if (ansiEnabled) {
s"$evPrim = $c.toIntExact();"
} else {
s"""
if ($c.toInt($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
}
"""
}
code"""
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toInt($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
}
$casting
$wrapper = null;
"""
case BooleanType =>
Expand All @@ -1445,15 +1509,21 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper])

(c, evPrim, evNull) =>
val casting = if (ansiEnabled) {
s"$evPrim = $c.toLongExact();"
} else {
s"""
if ($c.toLong($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
}
"""
}
code"""
UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
if ($c.toLong($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
}
$casting
$wrapper = null;
"""
case BooleanType =>
Expand All @@ -1473,6 +1543,22 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case StringType if ansiEnabled =>
val floatStr = ctx.freshVariable("floatStr", StringType)
(c, evPrim, evNull) =>
code"""
final String $floatStr = $c.toString();
try {
$evPrim = Float.valueOf($floatStr);
} catch (java.lang.NumberFormatException e) {
final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
if (f == null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can unify the code a little bit

val handleNull = if (ansiEnabled) {
  s"throw ..."
} else {
  s"$evNull = true;"
}
...
code"""
  ...
  if (f == null) {
    $handleNull
  } else ...
"""

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

throw new NumberFormatException("invalid input syntax for type numeric: $c");
} else {
$evPrim = f.floatValue();
}
}
"""
case StringType =>
val floatStr = ctx.freshVariable("floatStr", StringType)
(c, evPrim, evNull) =>
Expand Down Expand Up @@ -1504,6 +1590,22 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case StringType if ansiEnabled =>
val doubleStr = ctx.freshVariable("doubleStr", StringType)
(c, evPrim, evNull) =>
code"""
final String $doubleStr = $c.toString();
try {
$evPrim = Double.valueOf($doubleStr);
} catch (java.lang.NumberFormatException e) {
final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
if (d == null) {
throw new NumberFormatException("invalid input syntax for type numeric: $c");
} else {
$evPrim = d.doubleValue();
}
}
"""
case StringType =>
val doubleStr = ctx.freshVariable("doubleStr", StringType)
(c, evPrim, evNull) =>
Expand Down
Loading