Skip to content

Commit d3f7e0f

Browse files
committed
addressed comments and added tests
1 parent 7b7d7c4 commit d3f7e0f

File tree

2 files changed

+56
-3
lines changed

2 files changed

+56
-3
lines changed

python/pyspark/sql/functions.py

Lines changed: 27 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -40,11 +40,24 @@ def _function_obj(sc, is_math=False):
4040
return sc._jvm.mathfunctions
4141

4242

def _create_function(name, doc="", is_math=False, binary=False):
    """Create a DataFrame function by name, dispatching to the JVM.

    :param name: name of the JVM function to wrap.
    :param doc: docstring attached to the generated function.
    :param is_math: if True, look the function up on the math-functions object
                    (see `_function_obj`) instead of the default one.
    :param binary: if True, the generated function takes two column arguments.
    :return: (outside this hunk) the generated wrapper `_`.
    """
    def _(col1, col2=None):
        sc = SparkContext._active_spark_context
        if not binary:
            jc = getattr(_function_obj(sc, is_math), name)(
                col1._jc if isinstance(col1, Column) else col1)
        else:
            # Explicit `is None` check rather than `assert col2`:
            # 0 / 0.0 are legitimate second arguments (e.g. pow(col, 0))
            # but are falsy, and `assert` is removed entirely under `-O`.
            if col2 is None:
                raise ValueError("The second column for %s not provided!" % name)
            # users might write ints for simplicity. This would throw an
            # error on the JVM side, so coerce plain ints to floats here.
            # (`type(...) is int` deliberately excludes bool subclasses.)
            if type(col1) is int:
                col1 = col1 * 1.0
            if type(col2) is int:
                col2 = col2 * 1.0
            jc = getattr(_function_obj(sc, is_math), name)(
                col1._jc if isinstance(col1, Column) else col1,
                col2._jc if isinstance(col2, Column) else col2)
        return Column(jc)
    _.__name__ = name
    _.__doc__ = doc
@@ -107,14 +120,25 @@ def _(col):
107120
'measured in radians.'
108121
}
109122

123+
# math functions that take two arguments as input
_binary_math_functions = {
    # NOTE: the trailing space before the concatenation is required —
    # without it the pieces join as "...(x, y) topolar coordinates...".
    'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to ' +
             'polar coordinates (r, theta).',
    'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
    'pow': 'Returns the value of the first argument raised to the power of the second argument.'
}


# Generate a module-level wrapper for every registered function name.
for _name, _doc in _functions.items():
    globals()[_name] = _create_function(_name, _doc)
for _name, _doc in _math_functions.items():
    globals()[_name] = _create_function(_name, _doc, True)
for _name, _doc in _binary_math_functions.items():
    globals()[_name] = _create_function(_name, _doc, True, True)
del _name, _doc
__all__ += _functions.keys()
__all__ += _math_functions.keys()
__all__ += _binary_math_functions.keys()
__all__.sort()
python/pyspark/sql/tests.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -387,6 +387,35 @@ def test_aggregator(self):
387387
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
388388
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
389389

390+
def test_math_functions(self):
    """Unary and binary SQL math functions should agree with Python's `math`."""
    df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
    from pyspark.sql import functions
    import math

    def get_values(l):
        # Each collected Row holds one value; pull out the scalar.
        return [j[0] for j in l]

    def assert_close(a, b):
        c = get_values(b)
        diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
        # Must *assert*, not return: the original returned the boolean,
        # which every caller discarded, so mismatches never failed the test.
        self.assertTrue(sum(diff) == len(a), "%s != %s" % (a, c))

    assert_close([math.cos(i) for i in range(10)],
                 df.select(functions.cos(df.a)).collect())
    assert_close([math.cos(i) for i in range(10)],
                 df.select(functions.cos("a")).collect())
    assert_close([math.sin(i) for i in range(10)],
                 df.select(functions.sin(df.a)).collect())
    assert_close([math.sin(i) for i in range(10)],
                 df.select(functions.sin(df['a'])).collect())
    assert_close([math.pow(i, 2 * i) for i in range(10)],
                 df.select(functions.pow(df.a, df.b)).collect())
    assert_close([math.pow(i, 2) for i in range(10)],
                 df.select(functions.pow(df.a, 2)).collect())
    assert_close([math.pow(i, 2) for i in range(10)],
                 df.select(functions.pow(df.a, 2.0)).collect())
    assert_close([math.hypot(i, 2 * i) for i in range(10)],
                 df.select(functions.hypot(df.a, df.b)).collect())
390419
def test_save_and_load(self):
391420
df = self.df
392421
tmpPath = tempfile.mkdtemp()

0 commit comments

Comments
 (0)