@@ -40,11 +40,24 @@ def _function_obj(sc, is_math=False):
4040 return sc ._jvm .mathfunctions
4141
4242
43- def _create_function (name , doc = "" , is_math = False ):
43+ def _create_function (name , doc = "" , is_math = False , binary = False ):
4444 """ Create a function for aggregator by name"""
45- def _ (col ):
def _(col1, col2=None):
    # Wrapper generated by _create_function: calls the JVM-side function
    # called `name` (looked up on the object returned by _function_obj) and
    # wraps the result in a Column.
    #
    # col1 -- first argument; a Column is unwrapped to its `_jc` handle, any
    #         other value is passed through unchanged.
    # col2 -- second argument, only used when the enclosing `binary` flag is
    #         set; handled the same way as col1.
    #
    # Raises AssertionError when `binary` is set and col2 is omitted.
    sc = SparkContext._active_spark_context
    if binary:
        # Compare against None instead of relying on truthiness: 0 and 0.0
        # are legitimate second arguments (e.g. raising to the power 0), and
        # an identity check never invokes the argument's own truth-value
        # conversion.
        assert col2 is not None, "The second column for %s not provided!" % name
        # users might write ints for simplicity. This would throw an error on the JVM side.
        # The exact-type check deliberately leaves bool (an int subclass) untouched.
        args = [float(arg) if type(arg) is int else arg for arg in (col1, col2)]
    else:
        args = [col1]
    jargs = [arg._jc if isinstance(arg, Column) else arg for arg in args]
    jc = getattr(_function_obj(sc, is_math), name)(*jargs)
    return Column(jc)
4962 _ .__name__ = name
5063 _ .__doc__ = doc
@@ -107,14 +120,25 @@ def _(col):
107120 'measured in radians.'
108121}
109122
# Math functions that take two arguments as input. Each key is the name of the
# function to expose; each value is the docstring given to the generated wrapper.
_binary_math_functions = {
    'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to ' +
             'polar coordinates (r, theta).',
    'hypot': 'Computes `sqrt(a^2 + b^2)` without intermediate overflow or underflow.',
    'pow': 'Returns the value of the first argument raised to the power of the second argument.',
}
130+
110131
# Generate one module-level wrapper per documented function name. Each table
# is paired with the extra positional flags handed to _create_function:
# none for plain functions, (is_math,) for one-argument math functions and
# (is_math, binary) for two-argument math functions.
for _table, _flags in ((_functions, ()),
                       (_math_functions, (True,)),
                       (_binary_math_functions, (True, True))):
    for _name, _doc in _table.items():
        globals()[_name] = _create_function(_name, _doc, *_flags)
# Drop the loop temporaries so they do not linger as module attributes.
del _name, _doc, _table, _flags

# Export every generated wrapper and keep the export list sorted.
__all__.extend(_functions)
__all__.extend(_math_functions)
__all__.extend(_binary_math_functions)
__all__.sort()
119143
120144
0 commit comments