Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
__all__ = ['countDistinct', 'approxCountDistinct', 'udf']


def _create_function(name, doc=""):
def _create_function(name, doc="", is_math=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about changing is_math to "jvm_class" ?

and then remove _function_obj, and just pass sc._jvm.functions or sc._jvm.mathfunctions in.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

u can now remove is_math

""" Create a function for aggregator by name"""
def _(col):
sc = SparkContext._active_spark_context
Expand All @@ -54,7 +54,7 @@ def _(col):
'upper': 'Converts a string expression to upper case.',
'lower': 'Converts a string expression to upper case.',
'sqrt': 'Computes the square root of the specified float value.',
'abs': 'Computes the absolutle value.',
'abs': 'Computes the absolute value.',

'max': 'Aggregate function: returns the maximum value of the expression in a group.',
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
Expand Down
101 changes: 101 additions & 0 deletions python/pyspark/sql/mathfunctions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
A collection of builtin math functions
"""

from pyspark import SparkContext
from pyspark.sql.dataframe import Column

__all__ = []


def _create_unary_mathfunction(name, doc=""):
""" Create a unary mathfunction by name"""
def _(col):
sc = SparkContext._active_spark_context
jc = getattr(sc._jvm.mathfunctions, name)(col._jc if isinstance(col, Column) else col)
return Column(jc)
_.__name__ = name
_.__doc__ = doc
return _


def _create_binary_mathfunction(name, doc=""):
""" Create a binary mathfunction by name"""
def _(col1, col2):
sc = SparkContext._active_spark_context
# users might write ints for simplicity. This would throw an error on the JVM side.
if type(col1) is int:
col1 = col1 * 1.0
if type(col2) is int:
col2 = col2 * 1.0
jc = getattr(sc._jvm.mathfunctions, name)(col1._jc if isinstance(col1, Column) else col1,
col2._jc if isinstance(col2, Column) else col2)
return Column(jc)
_.__name__ = name
_.__doc__ = doc
return _


# math functions are found under another object therefore, they need to be handled separately
_mathfunctions = {
'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' +
'0.0 through pi.',
'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' +
'-pi/2 through pi/2.',
'atan': 'Computes the tangent inverse of the given value.',
'cbrt': 'Computes the cube-root of the given value.',
'ceil': 'Computes the ceiling of the given value.',
'cos': 'Computes the cosine of the given value.',
'cosh': 'Computes the hyperbolic cosine of the given value.',
'exp': 'Computes the exponential of the given value.',
'expm1': 'Computes the exponential of the given value minus one.',
'floor': 'Computes the floor of the given value.',
'log': 'Computes the natural logarithm of the given value.',
'log10': 'Computes the logarithm of the given value in Base 10.',
'log1p': 'Computes the natural logarithm of the given value plus one.',
'rint': 'Returns the double value that is closest in value to the argument and' +
' is equal to a mathematical integer.',
'signum': 'Computes the signum of the given value.',
'sin': 'Computes the sine of the given value.',
'sinh': 'Computes the hyperbolic sine of the given value.',
'tan': 'Computes the tangent of the given value.',
'tanh': 'Computes the hyperbolic tangent of the given value.',
'toDeg': 'Converts an angle measured in radians to an approximately equivalent angle ' +
'measured in degrees.',
'toRad': 'Converts an angle measured in degrees to an approximately equivalent angle ' +
'measured in radians.'
}

# math functions that take two arguments as input
_binary_mathfunctions = {
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
'polar coordinates (r, theta).',
'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
'pow': 'Returns the value of the first argument raised to the power of the second argument.'
}

for _name, _doc in _mathfunctions.items():
globals()[_name] = _create_unary_mathfunction(_name, _doc)
for _name, _doc in _binary_mathfunctions.items():
globals()[_name] = _create_binary_mathfunction(_name, _doc)
del _name, _doc
__all__ += _mathfunctions.keys()
__all__ += _binary_mathfunctions.keys()
__all__.sort()
29 changes: 29 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,35 @@ def test_aggregator(self):
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])

def test_math_functions(self):
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
from pyspark.sql import mathfunctions as functions
import math

def get_values(l):
return [j[0] for j in l]

def assert_close(a, b):
c = get_values(b)
diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
return sum(diff) == len(a)
assert_close([math.cos(i) for i in range(10)],
df.select(functions.cos(df.a)).collect())
assert_close([math.cos(i) for i in range(10)],
df.select(functions.cos("a")).collect())
assert_close([math.sin(i) for i in range(10)],
df.select(functions.sin(df.a)).collect())
assert_close([math.sin(i) for i in range(10)],
df.select(functions.sin(df['a'])).collect())
assert_close([math.pow(i, 2 * i) for i in range(10)],
df.select(functions.pow(df.a, df.b)).collect())
assert_close([math.pow(i, 2) for i in range(10)],
df.select(functions.pow(df.a, 2)).collect())
assert_close([math.pow(i, 2) for i in range(10)],
df.select(functions.pow(df.a, 2.0)).collect())
assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot(df.a, df.b)).collect())

def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,6 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
}
}

case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER")

case class Hypot(
left: Expression,
right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT")

case class Atan2(
left: Expression,
right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") {
Expand All @@ -91,3 +85,9 @@ case class Atan2(
}
}
}

case class Hypot(
left: Expression,
right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT")

case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER")
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,16 @@ import org.apache.spark.sql.types._
* input format, therefore these functions extend `ExpectsInputTypes`.
* @param name The short name of the function
*/
abstract class MathematicalExpression(name: String)
abstract class MathematicalExpression(f: Double => Double, name: String)
extends UnaryExpression with Serializable with ExpectsInputTypes {
self: Product =>
type EvaluatedType = Any

override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
override def foldable: Boolean = child.foldable
override def nullable: Boolean = true
override def toString: String = s"$name($child)"
}

/**
* A unary expression specifically for math functions that take a `Double` as input and return
* a `Double`.
* @param f The math function.
* @param name The short name of the function
*/
abstract class MathematicalExpressionForDouble(f: Double => Double, name: String)
extends MathematicalExpression(name) { self: Product =>

override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)

override def eval(input: Row): Any = {
val evalE = child.eval(input)
Expand All @@ -58,111 +47,46 @@ abstract class MathematicalExpressionForDouble(f: Double => Double, name: String
}
}

/**
* A unary expression specifically for math functions that take an `Int` as input and return
* an `Int`.
* @param f The math function.
* @param name The short name of the function
*/
abstract class MathematicalExpressionForInt(f: Int => Int, name: String)
extends MathematicalExpression(name) { self: Product =>
case class Acos(child: Expression) extends MathematicalExpression(math.acos, "ACOS")

override def dataType: DataType = IntegerType
override def expectedChildTypes: Seq[DataType] = Seq(IntegerType)
case class Asin(child: Expression) extends MathematicalExpression(math.asin, "ASIN")

override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) null else f(evalE.asInstanceOf[Int])
}
}
case class Atan(child: Expression) extends MathematicalExpression(math.atan, "ATAN")

/**
* A unary expression specifically for math functions that take a `Float` as input and return
* a `Float`.
* @param f The math function.
* @param name The short name of the function
*/
abstract class MathematicalExpressionForFloat(f: Float => Float, name: String)
extends MathematicalExpression(name) { self: Product =>
case class Cbrt(child: Expression) extends MathematicalExpression(math.cbrt, "CBRT")

override def dataType: DataType = FloatType
override def expectedChildTypes: Seq[DataType] = Seq(FloatType)
case class Ceil(child: Expression) extends MathematicalExpression(math.ceil, "CEIL")

override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) {
null
} else {
val result = f(evalE.asInstanceOf[Float])
if (result.isNaN) null else result
}
}
}

/**
* A unary expression specifically for math functions that take a `Long` as input and return
* a `Long`.
* @param f The math function.
* @param name The short name of the function
*/
abstract class MathematicalExpressionForLong(f: Long => Long, name: String)
extends MathematicalExpression(name) { self: Product =>

override def dataType: DataType = LongType
override def expectedChildTypes: Seq[DataType] = Seq(LongType)

override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) null else f(evalE.asInstanceOf[Long])
}
}

case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin, "SIN")

case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin, "ASIN")

case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh, "SINH")

case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos, "COS")
case class Cos(child: Expression) extends MathematicalExpression(math.cos, "COS")

case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos, "ACOS")
case class Cosh(child: Expression) extends MathematicalExpression(math.cosh, "COSH")

case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh, "COSH")
case class Exp(child: Expression) extends MathematicalExpression(math.exp, "EXP")

case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan, "TAN")
case class Expm1(child: Expression) extends MathematicalExpression(math.expm1, "EXPM1")

case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan, "ATAN")
case class Floor(child: Expression) extends MathematicalExpression(math.floor, "FLOOR")

case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh, "TANH")
case class Log(child: Expression) extends MathematicalExpression(math.log, "LOG")

case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil, "CEIL")
case class Log10(child: Expression) extends MathematicalExpression(math.log10, "LOG10")

case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor, "FLOOR")
case class Log1p(child: Expression) extends MathematicalExpression(math.log1p, "LOG1P")

case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint, "ROUND")
case class Rint(child: Expression) extends MathematicalExpression(math.rint, "ROUND")

case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt, "CBRT")
case class Signum(child: Expression) extends MathematicalExpression(math.signum, "SIGNUM")

case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum, "SIGNUM")
case class Sin(child: Expression) extends MathematicalExpression(math.sin, "SIN")

case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum, "ISIGNUM")
case class Sinh(child: Expression) extends MathematicalExpression(math.sinh, "SINH")

case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum, "FSIGNUM")
case class Tan(child: Expression) extends MathematicalExpression(math.tan, "TAN")

case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum, "LSIGNUM")
case class Tanh(child: Expression) extends MathematicalExpression(math.tanh, "TANH")

case class ToDegrees(child: Expression)
extends MathematicalExpressionForDouble(math.toDegrees, "DEGREES")
extends MathematicalExpression(math.toDegrees, "DEGREES")

case class ToRadians(child: Expression)
extends MathematicalExpressionForDouble(math.toRadians, "RADIANS")

case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log, "LOG")

case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10, "LOG10")

case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p, "LOG1P")

case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp, "EXP")

case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1, "EXPM1")
extends MathematicalExpression(math.toRadians, "RADIANS")
Original file line number Diff line number Diff line change
Expand Up @@ -1253,18 +1253,6 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
unaryMathFunctionEvaluation[Double](Signum, math.signum)
}

test("isignum") {
unaryMathFunctionEvaluation[Int](ISignum, math.signum, (-5 to 5))
}

test("fsignum") {
unaryMathFunctionEvaluation[Float](FSignum, math.signum, (-5 to 5).map(_.toFloat))
}

test("lsignum") {
unaryMathFunctionEvaluation[Long](LSignum, math.signum, (5 to 5).map(_.toLong))
}

test("log") {
unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1))
unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true)
Expand Down
Loading