-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-27768][SQL] Support Infinity/NaN-related float/double literals case-insensitively #25331
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1011c2c
19957ac
6ce5094
5b3b734
48795b1
e901dc4
41baaa0
429048b
1a5d978
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions | |
|
|
||
| import java.math.{BigDecimal => JavaBigDecimal} | ||
| import java.time.ZoneId | ||
| import java.util.Locale | ||
| import java.util.concurrent.TimeUnit._ | ||
|
|
||
| import org.apache.spark.SparkException | ||
|
|
@@ -192,6 +193,22 @@ object Cast { | |
| } | ||
|
|
||
| def resolvableNullability(from: Boolean, to: Boolean): Boolean = !from || to | ||
|
|
||
| /** | ||
| * We process literals such as 'Infinity', 'Inf', '-Infinity' and 'NaN' etc in case | ||
| * insensitive manner to be compatible with other database systems such as PostgreSQL and DB2. | ||
| */ | ||
| def processFloatingPointSpecialLiterals(v: String, isFloat: Boolean): Any = { | ||
| v.trim.toLowerCase(Locale.ROOT) match { | ||
| case "inf" | "+inf" | "infinity" | "+infinity" => | ||
| if (isFloat) Float.PositiveInfinity else Double.PositiveInfinity | ||
| case "-inf" | "-infinity" => | ||
| if (isFloat) Float.NegativeInfinity else Double.NegativeInfinity | ||
| case "nan" => | ||
| if (isFloat) Float.NaN else Double.NaN | ||
| case _ => null | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -562,8 +579,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String | |
| // DoubleConverter | ||
| private[this] def castToDouble(from: DataType): Any => Any = from match { | ||
| case StringType => | ||
| buildCast[UTF8String](_, s => try s.toString.toDouble catch { | ||
| case _: NumberFormatException => null | ||
| buildCast[UTF8String](_, s => { | ||
| val doubleStr = s.toString | ||
| try doubleStr.toDouble catch { | ||
| case _: NumberFormatException => | ||
| Cast.processFloatingPointSpecialLiterals(doubleStr, false) | ||
| } | ||
| }) | ||
| case BooleanType => | ||
| buildCast[Boolean](_, b => if (b) 1d else 0d) | ||
|
|
@@ -578,8 +599,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String | |
| // FloatConverter | ||
| private[this] def castToFloat(from: DataType): Any => Any = from match { | ||
| case StringType => | ||
| buildCast[UTF8String](_, s => try s.toString.toFloat catch { | ||
| case _: NumberFormatException => null | ||
| buildCast[UTF8String](_, s => { | ||
| val floatStr = s.toString | ||
| try floatStr.toFloat catch { | ||
| case _: NumberFormatException => | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we explicitly match strings we want? We can avoid try-catch of an exception for the known strings.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @viirya So my idea here was that.. since we don't expect these kind of literals a lot i.e its not a common case .. we don't change our normal processing path to add any possible runtime costs. Thats why we keep all these in the exception processing.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like the @viirya approach for readability, but I understand your concern for the performance. So, could you check actual performance changes?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @maropu Oops.. didn't see this comment. I suppose i have to use the benchmark framework for this ? Appreciate any tip on this..
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we check the numbers by a simple query?, e.g., In another pr, I observed that a logic depending on exceptions cause high performance penalties: lz4/lz4-java#143
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @maropu Actually my input data didn't contain any of these special literals. Basically i was testing with the condition when we don't hit the catch block. Basically we are trying optimize for the best case ?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just make sure that a small number of
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @maropu In the mean time i had tried with 1% of data being 'inf' and i can confirm that it does not hurt the performance :-)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds nice, too ;) |
||
| Cast.processFloatingPointSpecialLiterals(floatStr, true) | ||
| } | ||
| }) | ||
| case BooleanType => | ||
| buildCast[Boolean](_, b => if (b) 1f else 0f) | ||
|
|
@@ -717,9 +742,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String | |
| case ByteType => castToByteCode(from, ctx) | ||
| case ShortType => castToShortCode(from, ctx) | ||
| case IntegerType => castToIntCode(from, ctx) | ||
| case FloatType => castToFloatCode(from) | ||
| case FloatType => castToFloatCode(from, ctx) | ||
| case LongType => castToLongCode(from, ctx) | ||
| case DoubleType => castToDoubleCode(from) | ||
| case DoubleType => castToDoubleCode(from, ctx) | ||
|
|
||
| case array: ArrayType => | ||
| castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx) | ||
|
|
@@ -1259,48 +1284,66 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String | |
| (c, evPrim, evNull) => code"$evPrim = (long) $c;" | ||
| } | ||
|
|
||
| private[this] def castToFloatCode(from: DataType): CastFunction = from match { | ||
| case StringType => | ||
| (c, evPrim, evNull) => | ||
| code""" | ||
| private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = { | ||
| from match { | ||
| case StringType => | ||
| val floatStr = ctx.freshVariable("floatStr", StringType) | ||
| (c, evPrim, evNull) => | ||
| code""" | ||
| final String $floatStr = $c.toString(); | ||
| try { | ||
| $evPrim = Float.valueOf($c.toString()); | ||
| $evPrim = Float.valueOf($floatStr); | ||
| } catch (java.lang.NumberFormatException e) { | ||
| $evNull = true; | ||
| final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true); | ||
| if (f == null) { | ||
| $evNull = true; | ||
| } else { | ||
| $evPrim = f.floatValue(); | ||
| } | ||
| } | ||
| """ | ||
| case BooleanType => | ||
| (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;" | ||
| case DateType => | ||
| (c, evPrim, evNull) => code"$evNull = true;" | ||
| case TimestampType => | ||
| (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});" | ||
| case DecimalType() => | ||
| (c, evPrim, evNull) => code"$evPrim = $c.toFloat();" | ||
| case x: NumericType => | ||
| (c, evPrim, evNull) => code"$evPrim = (float) $c;" | ||
| case BooleanType => | ||
| (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;" | ||
| case DateType => | ||
| (c, evPrim, evNull) => code"$evNull = true;" | ||
| case TimestampType => | ||
| (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});" | ||
| case DecimalType() => | ||
| (c, evPrim, evNull) => code"$evPrim = $c.toFloat();" | ||
| case x: NumericType => | ||
| (c, evPrim, evNull) => code"$evPrim = (float) $c;" | ||
| } | ||
| } | ||
|
|
||
| private[this] def castToDoubleCode(from: DataType): CastFunction = from match { | ||
| case StringType => | ||
| (c, evPrim, evNull) => | ||
| code""" | ||
| private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = { | ||
| from match { | ||
| case StringType => | ||
| val doubleStr = ctx.freshVariable("doubleStr", StringType) | ||
| (c, evPrim, evNull) => | ||
| code""" | ||
| final String $doubleStr = $c.toString(); | ||
| try { | ||
| $evPrim = Double.valueOf($c.toString()); | ||
| $evPrim = Double.valueOf($doubleStr); | ||
| } catch (java.lang.NumberFormatException e) { | ||
| $evNull = true; | ||
| final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false); | ||
| if (d == null) { | ||
| $evNull = true; | ||
| } else { | ||
| $evPrim = d.doubleValue(); | ||
| } | ||
| } | ||
| """ | ||
| case BooleanType => | ||
| (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;" | ||
| case DateType => | ||
| (c, evPrim, evNull) => code"$evNull = true;" | ||
| case TimestampType => | ||
| (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};" | ||
| case DecimalType() => | ||
| (c, evPrim, evNull) => code"$evPrim = $c.toDouble();" | ||
| case x: NumericType => | ||
| (c, evPrim, evNull) => code"$evPrim = (double) $c;" | ||
| case BooleanType => | ||
| (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;" | ||
| case DateType => | ||
| (c, evPrim, evNull) => code"$evNull = true;" | ||
| case TimestampType => | ||
| (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};" | ||
| case DecimalType() => | ||
| (c, evPrim, evNull) => code"$evPrim = $c.toDouble();" | ||
| case x: NumericType => | ||
| (c, evPrim, evNull) => code"$evPrim = (double) $c;" | ||
| } | ||
| } | ||
|
|
||
| private[this] def castArrayCode( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually some literals like InFiniTy can be casted, it looks a bit weird. However, postgresql accepts such things. Maybe add a comment here to explain why allowing case insensitive match?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@viirya Does this look ok ?
We process literals such as 'Infinity', 'Inf', '-Infinity' and 'NaN' in case insensitive manner to be compatible with other database systems such as Postgres and DB2.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel ok. Other reviewers may also have advice too.