Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 72 additions & 5 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,36 @@
__all__ = ['countDistinct', 'approxCountDistinct', 'udf']


def _create_function(name, doc=""):
def _function_obj(sc, is_math=False):
if not is_math:
return sc._jvm.functions
else:
return sc._jvm.mathfunctions


def _create_function(name, doc="", is_math=False, binary=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if you got my previous comment. it might be easier if is_math just takes a jvm object, rather than using an extra _function_obj

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did... The problem is, I can't pass in sc when it was outside this function. I guess the problem was globally calling sc = SparkContext._active_spark_context, meaning outside of a function. That's why I had to have this hack.

""" Create a function for aggregator by name"""
def _(col):
def _(col1, col2=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is somewhat strange to have col2 be default None. I think it's easier if we just create a _create_binary_function function.

sc = SparkContext._active_spark_context
jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
if not binary:
jc = getattr(_function_obj(sc, is_math), name)(col1._jc if isinstance(col1, Column)
else col1)
else:
assert col2, "The second column for %s not provided!" % name
# users might write ints for simplicity. This would throw an error on the JVM side.
if type(col1) is int:
col1 = col1 * 1.0
if type(col2) is int:
col2 = col2 * 1.0
jc = getattr(_function_obj(sc, is_math), name)(col1._jc if isinstance(col1, Column)
else col1,
col2._jc if isinstance(col2, Column)
else col2)
return Column(jc)
_.__name__ = name
_.__doc__ = doc
return _


_functions = {
'lit': 'Creates a :class:`Column` of literal value.',
'col': 'Returns a :class:`Column` based on the given column name.',
Expand All @@ -54,7 +73,7 @@ def _(col):
'upper': 'Converts a string expression to upper case.',
'lower': 'Converts a string expression to upper case.',
'sqrt': 'Computes the square root of the specified float value.',
'abs': 'Computes the absolutle value.',
'abs': 'Computes the absolute value.',

'max': 'Aggregate function: returns the maximum value of the expression in a group.',
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
Expand All @@ -67,11 +86,59 @@ def _(col):
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
}

# math functions are found under another object therefore, they need to be handled separately
_math_functions = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually we should probably make this a parallel thing, i.e. mathfunctions?

'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' +
'0.0 through pi.',
'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' +
'-pi/2 through pi/2.',
'atan': 'Computes the tangent inverse of the given value.',
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
'polar coordinates (r, theta).',
'cbrt': 'Computes the cube-root of the given value.',
'ceil': 'Computes the ceiling of the given value.',
'cos': 'Computes the cosine of the given value.',
'cosh': 'Computes the hyperbolic cosine of the given value.',
'exp': 'Computes the exponential of the given value.',
'expm1': 'Computes the exponential of the given value minus one.',
'floor': 'Computes the floor of the given value.',
'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think u need to special case for binary columns. and add some tests for some functions in python sql tests.

'log': 'Computes the natural logarithm of the given value.',
'log10': 'Computes the logarithm of the given value in Base 10.',
'log1p': 'Computes the natural logarithm of the given value plus one.',
'pow': 'Returns the value of the first argument raised to the power of the second argument.',
'rint': 'Returns the double value that is closest in value to the argument and' +
' is equal to a mathematical integer.',
'signum': 'Computes the signum of the given value.',
'sin': 'Computes the sine of the given value.',
'sinh': 'Computes the hyperbolic sine of the given value.',
'tan': 'Computes the tangent of the given value.',
'tanh': 'Computes the hyperbolic tangent of the given value.',
'toDeg': 'Converts an angle measured in radians to an approximately equivalent angle ' +
'measured in degrees.',
'toRad': 'Converts an angle measured in degrees to an approximately equivalent angle ' +
'measured in radians.'
}

# math functions that take two arguments as input
_binary_math_functions = {
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
'polar coordinates (r, theta).',
'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
'pow': 'Returns the value of the first argument raised to the power of the second argument.'
}


for _name, _doc in _functions.items():
globals()[_name] = _create_function(_name, _doc)
for _name, _doc in _math_functions.items():
globals()[_name] = _create_function(_name, _doc, True)
for _name, _doc in _binary_math_functions.items():
globals()[_name] = _create_function(_name, _doc, True, True)
del _name, _doc
__all__ += _functions.keys()
__all__ += _math_functions.keys()
__all__ += _binary_math_functions.keys()
__all__.sort()


Expand Down
29 changes: 29 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,35 @@ def test_aggregator(self):
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])

def test_math_functions(self):
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
from pyspark.sql import functions
import math

def get_values(l):
return [j[0] for j in l]

def assert_close(a, b):
c = get_values(b)
diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
return sum(diff) == len(a)
assert_close([math.cos(i) for i in range(10)],
df.select(functions.cos(df.a)).collect())
assert_close([math.cos(i) for i in range(10)],
df.select(functions.cos("a")).collect())
assert_close([math.sin(i) for i in range(10)],
df.select(functions.sin(df.a)).collect())
assert_close([math.sin(i) for i in range(10)],
df.select(functions.sin(df['a'])).collect())
assert_close([math.pow(i, 2 * i) for i in range(10)],
df.select(functions.pow(df.a, df.b)).collect())
assert_close([math.pow(i, 2) for i in range(10)],
df.select(functions.pow(df.a, 2)).collect())
assert_close([math.pow(i, 2) for i in range(10)],
df.select(functions.pow(df.a, 2.0)).collect())
assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot(df.a, df.b)).collect())

def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,6 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
}
}

case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER")

case class Hypot(
left: Expression,
right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT")

case class Atan2(
left: Expression,
right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") {
Expand All @@ -91,3 +85,9 @@ case class Atan2(
}
}
}

case class Hypot(
left: Expression,
right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT")

case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER")
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,16 @@ import org.apache.spark.sql.types._
* input format, therefore these functions extend `ExpectsInputTypes`.
* @param name The short name of the function
*/
abstract class MathematicalExpression(name: String)
abstract class MathematicalExpression(f: Double => Double, name: String)
extends UnaryExpression with Serializable with ExpectsInputTypes {
self: Product =>
type EvaluatedType = Any

override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
override def foldable: Boolean = child.foldable
override def nullable: Boolean = true
override def toString: String = s"$name($child)"
}

/**
* A unary expression specifically for math functions that take a `Double` as input and return
* a `Double`.
* @param f The math function.
* @param name The short name of the function
*/
abstract class MathematicalExpressionForDouble(f: Double => Double, name: String)
extends MathematicalExpression(name) { self: Product =>

override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)

override def eval(input: Row): Any = {
val evalE = child.eval(input)
Expand All @@ -58,111 +47,46 @@ abstract class MathematicalExpressionForDouble(f: Double => Double, name: String
}
}

/**
* A unary expression specifically for math functions that take an `Int` as input and return
* an `Int`.
* @param f The math function.
* @param name The short name of the function
*/
abstract class MathematicalExpressionForInt(f: Int => Int, name: String)
extends MathematicalExpression(name) { self: Product =>
case class Acos(child: Expression) extends MathematicalExpression(math.acos, "ACOS")

override def dataType: DataType = IntegerType
override def expectedChildTypes: Seq[DataType] = Seq(IntegerType)
case class Asin(child: Expression) extends MathematicalExpression(math.asin, "ASIN")

override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) null else f(evalE.asInstanceOf[Int])
}
}
case class Atan(child: Expression) extends MathematicalExpression(math.atan, "ATAN")

/**
* A unary expression specifically for math functions that take a `Float` as input and return
* a `Float`.
* @param f The math function.
* @param name The short name of the function
*/
abstract class MathematicalExpressionForFloat(f: Float => Float, name: String)
extends MathematicalExpression(name) { self: Product =>
case class Cbrt(child: Expression) extends MathematicalExpression(math.cbrt, "CBRT")

override def dataType: DataType = FloatType
override def expectedChildTypes: Seq[DataType] = Seq(FloatType)
case class Ceil(child: Expression) extends MathematicalExpression(math.ceil, "CEIL")

override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) {
null
} else {
val result = f(evalE.asInstanceOf[Float])
if (result.isNaN) null else result
}
}
}

/**
* A unary expression specifically for math functions that take a `Long` as input and return
* a `Long`.
* @param f The math function.
* @param name The short name of the function
*/
abstract class MathematicalExpressionForLong(f: Long => Long, name: String)
extends MathematicalExpression(name) { self: Product =>

override def dataType: DataType = LongType
override def expectedChildTypes: Seq[DataType] = Seq(LongType)

override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) null else f(evalE.asInstanceOf[Long])
}
}

case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin, "SIN")

case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin, "ASIN")

case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh, "SINH")

case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos, "COS")
case class Cos(child: Expression) extends MathematicalExpression(math.cos, "COS")

case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos, "ACOS")
case class Cosh(child: Expression) extends MathematicalExpression(math.cosh, "COSH")

case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh, "COSH")
case class Exp(child: Expression) extends MathematicalExpression(math.exp, "EXP")

case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan, "TAN")
case class Expm1(child: Expression) extends MathematicalExpression(math.expm1, "EXPM1")

case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan, "ATAN")
case class Floor(child: Expression) extends MathematicalExpression(math.floor, "FLOOR")

case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh, "TANH")
case class Log(child: Expression) extends MathematicalExpression(math.log, "LOG")

case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil, "CEIL")
case class Log10(child: Expression) extends MathematicalExpression(math.log10, "LOG10")

case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor, "FLOOR")
case class Log1p(child: Expression) extends MathematicalExpression(math.log1p, "LOG1P")

case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint, "ROUND")
case class Rint(child: Expression) extends MathematicalExpression(math.rint, "ROUND")

case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt, "CBRT")
case class Signum(child: Expression) extends MathematicalExpression(math.signum, "SIGNUM")

case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum, "SIGNUM")
case class Sin(child: Expression) extends MathematicalExpression(math.sin, "SIN")

case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum, "ISIGNUM")
case class Sinh(child: Expression) extends MathematicalExpression(math.sinh, "SINH")

case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum, "FSIGNUM")
case class Tan(child: Expression) extends MathematicalExpression(math.tan, "TAN")

case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum, "LSIGNUM")
case class Tanh(child: Expression) extends MathematicalExpression(math.tanh, "TANH")

case class ToDegrees(child: Expression)
extends MathematicalExpressionForDouble(math.toDegrees, "DEGREES")
extends MathematicalExpression(math.toDegrees, "DEGREES")

case class ToRadians(child: Expression)
extends MathematicalExpressionForDouble(math.toRadians, "RADIANS")

case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log, "LOG")

case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10, "LOG10")

case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p, "LOG1P")

case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp, "EXP")

case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1, "EXPM1")
extends MathematicalExpression(math.toRadians, "RADIANS")
Original file line number Diff line number Diff line change
Expand Up @@ -1253,18 +1253,6 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
unaryMathFunctionEvaluation[Double](Signum, math.signum)
}

test("isignum") {
unaryMathFunctionEvaluation[Int](ISignum, math.signum, (-5 to 5))
}

test("fsignum") {
unaryMathFunctionEvaluation[Float](FSignum, math.signum, (-5 to 5).map(_.toFloat))
}

test("lsignum") {
unaryMathFunctionEvaluation[Long](LSignum, math.signum, (5 to 5).map(_.toLong))
}

test("log") {
unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1))
unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true)
Expand Down
Loading