Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ case class Cast(child: Expression, dataType: DataType)
case ByteType =>
buildCast[Byte](_, _ != 0)
case DecimalType() =>
buildCast[Decimal](_, _ != Decimal.ZERO)
buildCast[Decimal](_, !_.isZero)
case DoubleType =>
buildCast[Double](_, _ != 0)
case FloatType =>
Expand Down Expand Up @@ -190,7 +190,7 @@ case class Cast(child: Expression, dataType: DataType)
}

private[this] def decimalToTimestamp(d: Decimal): Long = {
(d.toBigDecimal * 1000000L).longValue()
d.toJavaBigDecimal.multiply(java.math.BigDecimal.valueOf(1000000L)).longValue()
}
private[this] def doubleToTimestamp(d: Double): Any = {
if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong
Expand Down Expand Up @@ -534,10 +534,7 @@ case class Cast(child: Expression, dataType: DataType)
(c, evPrim, evNull) =>
s"""
try {
org.apache.spark.sql.types.Decimal tmpDecimal =
new org.apache.spark.sql.types.Decimal().set(
new scala.math.BigDecimal(
new java.math.BigDecimal($c.toString())));
Decimal tmpDecimal = Decimal.apply(new java.math.BigDecimal($c.toString()));
${changePrecision("tmpDecimal", target, evPrim, evNull)}
} catch (java.lang.NumberFormatException e) {
$evNull = true;
Expand All @@ -546,11 +543,11 @@ case class Cast(child: Expression, dataType: DataType)
case BooleanType =>
(c, evPrim, evNull) =>
s"""
org.apache.spark.sql.types.Decimal tmpDecimal = null;
Decimal tmpDecimal = null;
if ($c) {
tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1);
tmpDecimal = new Decimal().set(1);
} else {
tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0);
tmpDecimal = new Decimal().set(0);
}
${changePrecision("tmpDecimal", target, evPrim, evNull)}
"""
Expand All @@ -561,32 +558,28 @@ case class Cast(child: Expression, dataType: DataType)
// Note that we lose precision here.
(c, evPrim, evNull) =>
s"""
org.apache.spark.sql.types.Decimal tmpDecimal =
new org.apache.spark.sql.types.Decimal().set(
scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
Decimal tmpDecimal = Decimal.apply(
java.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
${changePrecision("tmpDecimal", target, evPrim, evNull)}
"""
case DecimalType() =>
case dt: DecimalType =>
(c, evPrim, evNull) =>
s"""
org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone();
Decimal tmpDecimal = $c.clone();
${changePrecision("tmpDecimal", target, evPrim, evNull)}
"""
case LongType =>
case ByteType | ShortType | IntegerType | LongType =>

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this change ends up avoiding an unnecessary integral to floating point cast. Nice!

(c, evPrim, evNull) =>
s"""
org.apache.spark.sql.types.Decimal tmpDecimal =
new org.apache.spark.sql.types.Decimal().set($c);
Decimal tmpDecimal = Decimal.apply((long) $c);
${changePrecision("tmpDecimal", target, evPrim, evNull)}
"""
case x: NumericType =>
// All other numeric types can be represented precisely as Doubles
(c, evPrim, evNull) =>
s"""
try {
org.apache.spark.sql.types.Decimal tmpDecimal =
new org.apache.spark.sql.types.Decimal().set(
scala.math.BigDecimal.valueOf((double) $c));
Decimal tmpDecimal = Decimal.apply(java.math.BigDecimal.valueOf((double) $c));
${changePrecision("tmpDecimal", target, evPrim, evNull)}
} catch (java.lang.NumberFormatException e) {
$evNull = true;
Expand Down Expand Up @@ -646,7 +639,7 @@ case class Cast(child: Expression, dataType: DataType)
}

private[this] def decimalToTimestampCode(d: String): String =
s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()"
s"($d.toJavaBigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()"
private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L"
private[this] def timestampToIntegerCode(ts: String): String =
s"java.lang.Math.floor((double) $ts / 1000000L)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
val decimalAdd = "$plus"
s"""
${ctx.javaType(dataType)} r = $eval1.remainder($eval2);
if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
if (r.compare(Decimal.ZERO()) < 0) {
${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2);
} else {
${ev.primitive} = r;
Expand Down
Loading