From a06c924e50c390cafdea4feb9a7290658a6749a9 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 16 Jun 2015 12:21:27 +0800 Subject: [PATCH 1/7] Enable NaN for acos & asin, handle -Infinity for log --- .../spark/sql/catalyst/expressions/math.scala | 82 ++++++++++------ .../expressions/MathFunctionsSuite.scala | 95 ++++++++++++++++--- 2 files changed, 137 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 42c596b5b31a..bed7a7d8962a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -66,8 +66,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) if (evalE == null) { null } else { - val result = f(evalE.asInstanceOf[Double]) - if (result.isNaN) null else result + f(evalE.asInstanceOf[Double]) } } @@ -81,9 +80,31 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } + } + """ + } +} + +abstract class UnaryMathLogExpression(f: Double => Double, name: String) + extends UnaryMathExpression(f, name) { + self: Product => + + override def eval(input: InternalRow): Any = { + val evalE = child.eval(input) + if (evalE == null || evalE.asInstanceOf[Double] <= 0.0) { + null + } else { + f(evalE.asInstanceOf[Double]) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) <= 0.0; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); } """ } @@ -113,8 +134,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) if (evalE2 == null) { null } else { - val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) - if (result.isNaN) null else result + f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) } } } @@ -160,28 +180,45 @@ case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXP case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") -case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") +case class Log(child: Expression) extends UnaryMathLogExpression(math.log, "LOG") case class Log2(child: Expression) - extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { + extends UnaryMathLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; + boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) <= 0.0; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2); - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } } """ } } -case class Log10(child: Expression) extends 
UnaryMathExpression(math.log10, "LOG10") +case class Log10(child: Expression) extends UnaryMathLogExpression(math.log10, "LOG10") + +case class Log1p(child: Expression) extends UnaryMathLogExpression(math.log1p, "LOG1P") { + override def eval(input: InternalRow): Any = { + val evalE = child.eval(input) + if (evalE == null || evalE.asInstanceOf[Double] + 1 <= 0.0) { + null + } else { + math.log1p(evalE.asInstanceOf[Double]) + } + } -case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) + 1 <= 0.0; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); + } + """ + } +} case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { override def funcName: String = "rint" @@ -226,19 +263,14 @@ case class Atan2(left: Expression, right: Expression) null } else { // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 - val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + math.atan2(evalE1.asInstanceOf[Double] + 0.0, evalE2.asInstanceOf[Double] + 0.0) - if (result.isNaN) null else result } } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") } } @@ -248,10 +280,6 @@ case class Hypot(left: Expression, right: Expression) case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 864c954ee82c..5dae82da1c84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} +import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types.DoubleType class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -49,11 +52,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { c: Expression => Expression, f: T => T, domain: Iterable[T] = (-20 to 20).map(_ * 0.1), - expectNull: Boolean = false): Unit = { + expectNull: Boolean = false, + expectNaN: Boolean = false): Unit = { if (expectNull) { domain.foreach { value => checkEvaluation(c(Literal(value)), null, 
EmptyRow) } + } else if (expectNaN) { + domain.foreach { value => + checkNaN(c(Literal(value)), EmptyRow) + } } else { domain.foreach { value => checkEvaluation(c(Literal(value)), f(value), EmptyRow) @@ -73,11 +81,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { c: (Expression, Expression) => Expression, f: (Double, Double) => Double, domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), - expectNull: Boolean = false): Unit = { + expectNull: Boolean = false, + expectNaN: Boolean = false): Unit = { if (expectNull) { domain.foreach { case (v1, v2) => checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null)) } + } else if (expectNaN) { + domain.foreach { case (v1, v2) => + checkNaN(c(Literal(v1), Literal(v2)), create_row(null)) + } } else { domain.foreach { case (v1, v2) => checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) @@ -88,6 +101,62 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) } + private def checkNaN( + expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { + checkNaNWithoutCodegen(expression, inputRow) + checkNaNWithGeneratedProjection(expression, inputRow) + checkNaNWithOptimization(expression, inputRow) + } + + private def checkNaNWithoutCodegen( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if (!actual.asInstanceOf[Double].isNaN) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect evaluation (codegen off): $expression, " + + s"actual: $actual, " + + s"expected: NaN$input") + } + } + + + private def checkNaNWithGeneratedProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): Unit = { + + val plan = try { + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val ctx = GenerateProjection.newCodeGenContext() + val evaluated = expression.gen(ctx) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if (!actual.asInstanceOf[Double].isNaN) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN$input") + } + } + + private def checkNaNWithOptimization( + expression: Expression, + inputRow: InternalRow = EmptyRow): Unit = { + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) + val optimizedPlan = DefaultOptimizer.execute(plan) + checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) + } + test("e") { testLeaf(EulerNumber, math.E) } @@ -102,7 +171,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("asin") { testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) - testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNull = true) + testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true) } test("sinh") { @@ -115,7 +184,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("acos") { testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) - testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNull = true) + testUnary(Acos, math.acos, (11 to 
20).map(_ * 0.1), expectNaN = true) } test("cosh") { @@ -171,29 +240,29 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("log") { - testUnary(Log, math.log, (0 to 20).map(_ * 0.1)) - testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNull = true) + testUnary(Log, math.log, (1 to 20).map(_ * 0.1)) + testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true) } test("log10") { - testUnary(Log10, math.log10, (0 to 20).map(_ * 0.1)) - testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNull = true) + testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1)) + testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true) } test("log1p") { - testUnary(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) - testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true) + testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1)) + testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) } test("log2") { def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) - testUnary(Log2, f, (0 to 20).map(_ * 0.1)) - testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true) + testUnary(Log2, f, (1 to 20).map(_ * 0.1)) + testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true) } test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) - testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) + testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) } test("hypot") { From 2d1dfc16d1e251cca432b6ef839620ad3e8a9fd5 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 16 Jun 2015 12:34:02 +0800 Subject: [PATCH 2/7] Enable udf7, udf_acos, udf_asin --- .../spark/sql/hive/execution/HiveCompatibilitySuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 82c0b494598a..b092bc3627fe 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -819,19 +819,19 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf2", "udf5", "udf6", - // "udf7", turn this on after we figure out null vs nan vs infinity + "udf7", "udf8", "udf9", "udf_10_trims", "udf_E", "udf_PI", "udf_abs", - // "udf_acos", turn this on after we figure out null vs nan vs infinity + "udf_acos", "udf_add", "udf_array", "udf_array_contains", "udf_ascii", - // "udf_asin", turn this on after we figure out null vs nan vs infinity + "udf_asin", "udf_atan", "udf_avg", "udf_bigint", From ef8c28d744a4c7c841633c6278096e8564ef09f6 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 16 Jun 2015 17:42:02 +0800 Subject: [PATCH 3/7] log1p(-1) should be null rather than -Infinity --- .../test/scala/org/apache/spark/sql/MathExpressionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index faa1d1193b50..48e569526ccc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -69,7 +69,7 
@@ class MathExpressionsSuite extends QueryTest { if (f(-1) === math.log1p(-1)) { checkAnswer( nnDoubleData.select(c('b)), - (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null) ) } else { checkAnswer( From 4be400a5581fca583a9d2949510b76444cb1de11 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 18 Jun 2015 23:47:05 +0800 Subject: [PATCH 4/7] Refactor Unary logs as well as fix the binary one --- .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../spark/sql/catalyst/expressions/math.scala | 74 +++++++++---------- .../expressions/MathFunctionsSuite.scala | 1 - .../spark/sql/MathExpressionsSuite.scala | 2 +- 4 files changed, 37 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 13b2bb05f528..944a481f5c5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -114,11 +114,11 @@ object FunctionRegistry { expression[Hypot]("hypot"), expression[Logarithm]("log"), expression[Log]("ln"), - expression[Log10]("log10"), expression[Log1p]("log1p"), + expression[Log10]("log10"), + expression[Log2]("log2"), expression[UnaryMinus]("negative"), expression[Pi]("pi"), - expression[Log2]("log2"), expression[Pow]("pow"), expression[Pow]("power"), expression[UnaryPositive]("positive"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 5aa9f779e38a..ace68648e1b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -85,13 +85,19 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) } } -abstract class UnaryMathLogExpression(f: Double => Double, name: String) +/** + * A expression specifically for unary log functions. 
+ * @param f The math function for non codegen evaluation + * @param name The short name of the log function + * @param yAsymptote values less than or equal to yAsymptote are considered eval to null + */ +abstract class UnaryLogarithmExpression(f: Double => Double, name: String, yAsymptote: Double) extends UnaryMathExpression(f, name) { self: Product => override def eval(input: InternalRow): Any = { val evalE = child.eval(input) - if (evalE == null || evalE.asInstanceOf[Double] <= 0.0) { + if (evalE == null || evalE.asInstanceOf[Double] <= yAsymptote) { null } else { f(evalE.asInstanceOf[Double]) @@ -101,7 +107,7 @@ abstract class UnaryMathLogExpression(f: Double => Double, name: String) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) <= 0.0; + boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) <= $yAsymptote; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); @@ -139,8 +145,10 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } + def funcName = name.toLowerCase + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${funcName}($c1, $c2)") } } @@ -180,41 +188,22 @@ case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXP case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") -case class Log(child: Expression) extends UnaryMathLogExpression(math.log, "LOG") +case class Log(child: Expression) extends UnaryLogarithmExpression(math.log, "LOG", 0.0) -case class Log2(child: Expression) - extends UnaryMathLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) <= 0.0; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2); - } - """ - } -} +case class Log10(child: Expression) extends UnaryLogarithmExpression(math.log10, "LOG10", 0.0) -case class Log10(child: Expression) extends UnaryMathLogExpression(math.log10, "LOG10") +case class Log1p(child: Expression) extends UnaryLogarithmExpression(math.log1p, "LOG1P", -1.0) -case class Log1p(child: Expression) extends UnaryMathLogExpression(math.log1p, "LOG1P") { - override def eval(input: InternalRow): Any = { - val evalE = child.eval(input) - if (evalE == null || evalE.asInstanceOf[Double] + 1 <= 0.0) { - null - } else { - math.log1p(evalE.asInstanceOf[Double]) - } - } +case class Log2(child: Expression) + extends UnaryLogarithmExpression((x: Double) => math.log(x) / math.log(2), "LOG2", 0.0) { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) + 1 <= 0.0; + boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) <= 0.0; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if 
(!${ev.isNull}) { - ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); + ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2); } """ } @@ -286,19 +275,24 @@ case class Pow(left: Expression, right: Expression) case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { - def this(child: Expression) = { - this(EulerNumber(), child) - } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val logCode = if (left.isInstanceOf[EulerNumber]) { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)") - } else { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") - } - logCode + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull} || ${eval1.primitive} <= 0.0; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (${ev.isNull}) { ${ev.isNull} = true; + } else { + ${eval2.code} + if (${eval2.isNull} || ${eval2.primitive} <= 0.0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.${funcName}(${eval2.primitive}) / + java.lang.Math.${funcName}(${eval1.primitive}); + } } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 597cda3b0963..76355103a369 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -280,7 +280,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { domain.foreach { case (v1, v2) => checkEvaluation(Logarithm(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) - checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow) } checkEvaluation( Logarithm(Literal.create(null, DoubleType), Literal(1.0)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 524336cec737..b42ba221013e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -245,7 +245,7 @@ class MathExpressionsSuite extends QueryTest { Row(math.log(123), math.log(123) / math.log(2), null)) checkAnswer( - df.selectExpr("log(a)", "log(2.0, a)", "log(b)"), + df.selectExpr("ln(a)", "log(2.0, a)", "ln(b)"), Row(math.log(123), math.log(123) / math.log(2), null)) } From 307ba7ee80226b67acd45784b00e1bde0e788623 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 18 Jun 2015 23:56:36 +0800 Subject: [PATCH 5/7] remove unnecessary Double.valueOf --- .../org/apache/spark/sql/catalyst/expressions/math.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index ace68648e1b1..01d145519303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -107,7 +107,7 @@ abstract class UnaryLogarithmExpression(f: Double => Double, name: String, yAsym override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) <= $yAsymptote; + boolean ${ev.isNull} = ${eval.isNull} || ${eval.primitive} <= $yAsymptote; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); @@ -200,7 +200,7 @@ case class Log2(child: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull} || Double.valueOf(${eval.primitive}) <= 0.0; + boolean ${ev.isNull} = ${eval.isNull} || ${eval.primitive} <= 0.0; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2); From a150de5cada684c78547c160d81d10b40708bfa0 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 18 Jun 2015 23:58:51 +0800 Subject: [PATCH 6/7] Style fix --- .../scala/org/apache/spark/sql/catalyst/expressions/math.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 01d145519303..2fa81aeaf61b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -145,7 +145,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } - def funcName = name.toLowerCase + def funcName: String = name.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${funcName}($c1, $c2)") From f19f651ea0d658a52b7a8f866d7f72d71f91408f Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 19 Jun 2015 02:07:30 +0800 Subject: [PATCH 7/7] Compare the behavior of binary log with hive --- .../spark/sql/catalyst/expressions/math.scala | 50 +++++++++++++++++++ .../execution/HiveCompatibilitySuite.scala | 2 +- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 2fa81aeaf61b..e49fc302a874 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -276,6 +276,20 @@ case class Pow(left: Expression, right: Expression) case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { + override def eval(input: InternalRow): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null || evalE1.asInstanceOf[Double] <= 0.0) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null || evalE2.asInstanceOf[Double] <= 0.0) { + null + } else { + math.log(evalE2.asInstanceOf[Double]) / math.log(evalE1.asInstanceOf[Double]) + } + } + } + override def genCode(ctx: CodeGenContext, ev: 
GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -296,4 +310,40 @@ case class Logarithm(left: Expression, right: Expression) } """ } + + // TODO: Hive's UDFLog doesn't support base in range (0.0, 1.0] + // If we want just behaves like Hive, use the code below and turn `udf_7` on + +// override def eval(input: InternalRow): Any = { +// val evalE1 = left.eval(input) +// val evalE2 = right.eval(input) +// if (evalE1 == null || evalE2 == null) { +// null +// } else { +// if (evalE1.asInstanceOf[Double] <= 1.0 || evalE2.asInstanceOf[Double] <= 0.0) { +// null +// } else { +// math.log(evalE2.asInstanceOf[Double]) / math.log(evalE1.asInstanceOf[Double]) +// } +// } +// } +// +// override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { +// val eval1 = left.gen(ctx) +// val eval2 = right.gen(ctx) +// s""" +// ${eval1.code} +// ${eval2.code} +// boolean ${ev.isNull} = ${eval1.isNull} || ${eval2.isNull}; +// ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; +// if (!${ev.isNull}) { +// if (${eval2.primitive} <= 1.0 || ${eval2.primitive} <= 0.0) { +// ${ev.isNull} = true; +// } else { +// ${ev.primitive} = java.lang.Math.${funcName}(${eval2.primitive}) / +// java.lang.Math.${funcName}(${eval1.primitive}); +// } +// } +// """ +// } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 3abfa6536d98..a50690014709 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -819,7 +819,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf2", "udf5", "udf6", - "udf7", + // "udf7", TODO if we don't allow log base in (0, 1.0], we should turn this on "udf8", "udf9", "udf_10_trims",
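

Note (not part of the patch series): a minimal, self-contained Scala sketch of the JVM math behavior this series works around and of the null-check pattern the new `UnaryLogarithmExpression` uses. The `safeLog` helper and its parameter names are illustrative only; they are not identifiers from the patch, and the real expressions operate on Catalyst rows and generated Java code rather than `Option`.

```scala
// Plain-JVM behavior motivating the change, plus a sketch of the asymptote-based null check.
object LogNaNSemanticsSketch {
  def main(args: Array[String]): Unit = {
    // Out-of-domain inverse trig and pow produce NaN; after this series,
    // Asin/Acos/Pow pass the NaN through instead of converting it to null.
    println(math.asin(2.0).isNaN)                        // true
    println(math.pow(-2.2, 1.7).isNaN)                   // true

    // Non-positive log inputs produce NaN or -Infinity on the JVM; the patch maps
    // both to SQL NULL by testing the *input* against the function's asymptote
    // before calling the math function, in both interpreted and generated code.
    println(math.log(-1.0).isNaN)                        // true
    println(math.log(0.0) == Double.NegativeInfinity)    // true
    println(math.log1p(-1.0) == Double.NegativeInfinity) // true

    // Hypothetical helper mirroring UnaryLogarithmExpression.eval:
    // inputs <= yAsymptote (0.0 for log/log10/log2, -1.0 for log1p) become None (SQL NULL).
    def safeLog(x: Option[Double], yAsymptote: Double, f: Double => Double): Option[Double] =
      x.filter(_ > yAsymptote).map(f)

    println(safeLog(Some(0.0), 0.0, math.log))     // None, i.e. SQL NULL
    println(safeLog(Some(2.0), 0.0, math.log))     // Some(0.6931...)
    println(safeLog(Some(-1.0), -1.0, math.log1p)) // None, i.e. SQL NULL
    println(safeLog(None, 0.0, math.log))          // None (null input propagates)
  }
}
```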