Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.google.common.math.LongMath

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts.implicitCast
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
Expand Down Expand Up @@ -223,6 +224,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
def f: (Double) => Double = (x: Double) => 1 / math.tan(x)
testUnary(Cot, f)
checkConsistencyBetweenInterpretedAndCodegen(Cot, DoubleType)
val nullLit = Literal.create(null, NullType)
val intNullLit = Literal.create(null, IntegerType)
val intLit = Literal.create(1, IntegerType)
checkEvaluation(checkDataTypeAndCast(Cot(nullLit)), null, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Cot(intNullLit)), null, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Cot(intLit)), 1 / math.tan(1), EmptyRow)
checkEvaluation(checkDataTypeAndCast(Cot(-intLit)), 1 / math.tan(-1), EmptyRow)
checkEvaluation(checkDataTypeAndCast(Cot(0)), 1 / math.tan(0), EmptyRow)
}

test("atan") {
Expand Down Expand Up @@ -250,6 +259,11 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType)
}

def checkDataTypeAndCast(expression: UnaryMathExpression): Expression = {
val expNew = implicitCast(expression.child, expression.inputTypes(0)).getOrElse(expression)
expression.withNewChildren(Seq(expNew))
}

test("ceil") {
testUnary(Ceil, (d: Double) => math.ceil(d).toLong)
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType)
Expand All @@ -262,12 +276,22 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val doublePi: Double = 3.1415
val floatPi: Float = 3.1415f
val longLit: Long = 12345678901234567L
checkEvaluation(Ceil(doublePi), 4L, EmptyRow)
checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow)
checkEvaluation(Ceil(longLit), longLit, EmptyRow)
checkEvaluation(Ceil(-doublePi), -3L, EmptyRow)
checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow)
checkEvaluation(Ceil(-longLit), -longLit, EmptyRow)
val nullLit = Literal.create(null, NullType)
val floatNullLit = Literal.create(null, FloatType)
checkEvaluation(checkDataTypeAndCast(Ceil(doublePi)), 4L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Ceil(floatPi)), 4L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Ceil(longLit)), longLit, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Ceil(-doublePi)), -3L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Ceil(-floatPi)), -3L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Ceil(-longLit)), -longLit, EmptyRow)

checkEvaluation(checkDataTypeAndCast(Ceil(nullLit)), null, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Ceil(floatNullLit)), null, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Ceil(0)), 0L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Ceil(1)), 1L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Ceil(1234567890123456L)), 1234567890123456L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Ceil(0.01)), 1L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Ceil(-0.10)), 0L, EmptyRow)
}

test("floor") {
Expand All @@ -282,12 +306,22 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val doublePi: Double = 3.1415
val floatPi: Float = 3.1415f
val longLit: Long = 12345678901234567L
checkEvaluation(Floor(doublePi), 3L, EmptyRow)
checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow)
checkEvaluation(Floor(longLit), longLit, EmptyRow)
checkEvaluation(Floor(-doublePi), -4L, EmptyRow)
checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow)
checkEvaluation(Floor(-longLit), -longLit, EmptyRow)
val nullLit = Literal.create(null, NullType)
val floatNullLit = Literal.create(null, FloatType)
checkEvaluation(checkDataTypeAndCast(Floor(doublePi)), 3L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Floor(floatPi)), 3L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Floor(longLit)), longLit, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Floor(-doublePi)), -4L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Floor(-floatPi)), -4L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Floor(-longLit)), -longLit, EmptyRow)

checkEvaluation(checkDataTypeAndCast(Floor(nullLit)), null, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Floor(floatNullLit)), null, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Floor(0)), 0L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Floor(1)), 1L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Floor(1234567890123456L)), 1234567890123456L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Floor(0.01)), 0L, EmptyRow)
checkEvaluation(checkDataTypeAndCast(Floor(-0.10)), -1L, EmptyRow)
}

test("factorial") {
Expand Down Expand Up @@ -541,10 +575,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val intPi: Int = 314159265
val longPi: Long = 31415926535897932L
val bdPi: BigDecimal = BigDecimal(31415927L, 7)
val floatPi: Float = 3.1415f

val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142,
3.1416, 3.14159, 3.141593)

val floatResults: Seq[Float] = Seq(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 3.1f, 3.14f,
3.141f, 3.1415f, 3.1415f, 3.1415f)

val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++
Seq.fill[Short](7)(31415)

Expand All @@ -563,10 +601,12 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
checkEvaluation(Round(floatPi, scale), floatResults(i), EmptyRow)
checkEvaluation(BRound(doublePi, scale), doubleResults(i), EmptyRow)
checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow)
checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow)
checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow)
checkEvaluation(BRound(floatPi, scale), floatResults(i), EmptyRow)
}

val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
Expand Down