Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ Functions
covar_samp
crc32
create_map
csc
cume_dist
current_date
current_timestamp
Expand Down Expand Up @@ -511,6 +512,7 @@ Functions
rtrim
schema_of_csv
schema_of_json
sec
second
sentences
sequence
Expand Down
58 changes: 58 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ def product(col):

def acos(col):
"""
Computes inverse cosine of the input column.

.. versionadded:: 1.4.0

Returns
Expand All @@ -278,6 +280,8 @@ def acosh(col):

def asin(col):
"""
Computes inverse sine of the input column.

.. versionadded:: 1.3.0


Expand All @@ -304,6 +308,8 @@ def asinh(col):

def atan(col):
"""
Compute inverse tangent of the input column.

.. versionadded:: 1.4.0

Returns
Expand Down Expand Up @@ -345,6 +351,8 @@ def ceil(col):

def cos(col):
"""
Computes cosine of the input column.

.. versionadded:: 1.4.0

Parameters
Expand All @@ -362,6 +370,8 @@ def cos(col):

def cosh(col):
"""
Computes hyperbolic cosine of the input column.

.. versionadded:: 1.4.0

Parameters
Expand All @@ -379,6 +389,8 @@ def cosh(col):

def cot(col):
"""
Computes cotangent of the input column.

.. versionadded:: 3.3.0

Parameters
Expand All @@ -394,6 +406,25 @@ def cot(col):
return _invoke_function_over_column("cot", col)


def csc(col):
"""
Computes cosecant of the input column.

.. versionadded:: 3.3.0
Copy link
Member

@HyukjinKwon HyukjinKwon Sep 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we add a short description?

Copy link
Contributor Author

@yutoacts yutoacts Sep 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, just added a description to all trig functions that missed it.


Parameters
----------
col : :class:`~pyspark.sql.Column` or str
Angle in radians

Returns
-------
:class:`~pyspark.sql.Column`
Cosecant of the angle.
"""
return _invoke_function_over_column("csc", col)


@since(1.4)
def exp(col):
"""
Expand Down Expand Up @@ -451,6 +482,25 @@ def rint(col):
return _invoke_function_over_column("rint", col)


def sec(col):
"""
Computes secant of the input column.

.. versionadded:: 3.3.0

Parameters
----------
col : :class:`~pyspark.sql.Column` or str
Angle in radians

Returns
-------
:class:`~pyspark.sql.Column`
Secant of the angle.
"""
return _invoke_function_over_column("sec", col)


@since(1.4)
def signum(col):
"""
Expand All @@ -461,6 +511,8 @@ def signum(col):

def sin(col):
"""
Computes sine of the input column.

.. versionadded:: 1.4.0

Parameters
Expand All @@ -477,6 +529,8 @@ def sin(col):

def sinh(col):
"""
Computes hyperbolic sine of the input column.

.. versionadded:: 1.4.0

Parameters
Expand All @@ -495,6 +549,8 @@ def sinh(col):

def tan(col):
"""
Computes tangent of the input column.

.. versionadded:: 1.4.0

Parameters
Expand All @@ -512,6 +568,8 @@ def tan(col):

def tanh(col):
"""
Computes hyperbolic tangent of the input column.

.. versionadded:: 1.4.0

Parameters
Expand Down
2 changes: 2 additions & 0 deletions python/pyspark/sql/functions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def cos(col: ColumnOrName) -> Column: ...
def cosh(col: ColumnOrName) -> Column: ...
def cot(col: ColumnOrName) -> Column: ...
def count(col: ColumnOrName) -> Column: ...
def csc(col: ColumnOrName) -> Column: ...
def cume_dist() -> Column: ...
def degrees(col: ColumnOrName) -> Column: ...
def dense_rank() -> Column: ...
Expand Down Expand Up @@ -339,6 +340,7 @@ def rank() -> Column: ...
def rint(col: ColumnOrName) -> Column: ...
def row_number() -> Column: ...
def rtrim(col: ColumnOrName) -> Column: ...
def sec(col: ColumnOrName) -> Column: ...
def signum(col: ColumnOrName) -> Column: ...
def sin(col: ColumnOrName) -> Column: ...
def sinh(col: ColumnOrName) -> Column: ...
Expand Down
78 changes: 44 additions & 34 deletions python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
import datetime
from itertools import chain
import re
import math

from py4j.protocol import Py4JJavaError
from pyspark.sql import Row, Window
from pyspark.sql import Row, Window, types
from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, \
lit, assert_true, sum_distinct, sumDistinct, shiftleft, shiftLeft, shiftRight, \
shiftright, shiftrightunsigned, shiftRightUnsigned, octet_length, bit_length
from pyspark.testing.sqlutils import ReusedSQLTestCase
shiftright, shiftrightunsigned, shiftRightUnsigned, octet_length, bit_length, \
sec, csc, cot
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils


class FunctionsTests(ReusedSQLTestCase):
Expand Down Expand Up @@ -109,38 +111,29 @@ def test_crosstab(self):
def test_math_functions(self):
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
from pyspark.sql import functions
import math

def get_values(l):
return [j[0] for j in l]

def assert_close(a, b):
c = get_values(b)
diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
return sum(diff) == len(a)

assert_close([math.cos(i) for i in range(10)],
df.select(functions.cos(df.a)).collect())
assert_close([math.cos(i) for i in range(10)],
df.select(functions.cos("a")).collect())
assert_close([math.sin(i) for i in range(10)],
df.select(functions.sin(df.a)).collect())
assert_close([math.sin(i) for i in range(10)],
df.select(functions.sin(df['a'])).collect())
assert_close([math.pow(i, 2 * i) for i in range(10)],
df.select(functions.pow(df.a, df.b)).collect())
assert_close([math.pow(i, 2) for i in range(10)],
df.select(functions.pow(df.a, 2)).collect())
assert_close([math.pow(i, 2) for i in range(10)],
df.select(functions.pow(df.a, 2.0)).collect())
assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot(df.a, df.b)).collect())
assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot("a", u"b")).collect())
assert_close([math.hypot(i, 2) for i in range(10)],
df.select(functions.hypot("a", 2)).collect())
assert_close([math.hypot(i, 2) for i in range(10)],
df.select(functions.hypot(df.a, 2)).collect())
SQLTestUtils.assert_close([math.cos(i) for i in range(10)],
df.select(functions.cos(df.a)).collect())
SQLTestUtils.assert_close([math.cos(i) for i in range(10)],
df.select(functions.cos("a")).collect())
SQLTestUtils.assert_close([math.sin(i) for i in range(10)],
df.select(functions.sin(df.a)).collect())
SQLTestUtils.assert_close([math.sin(i) for i in range(10)],
df.select(functions.sin(df['a'])).collect())
SQLTestUtils.assert_close([math.pow(i, 2 * i) for i in range(10)],
df.select(functions.pow(df.a, df.b)).collect())
SQLTestUtils.assert_close([math.pow(i, 2) for i in range(10)],
df.select(functions.pow(df.a, 2)).collect())
SQLTestUtils.assert_close([math.pow(i, 2) for i in range(10)],
df.select(functions.pow(df.a, 2.0)).collect())
SQLTestUtils.assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot(df.a, df.b)).collect())
SQLTestUtils.assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot("a", u"b")).collect())
SQLTestUtils.assert_close([math.hypot(i, 2) for i in range(10)],
df.select(functions.hypot("a", 2)).collect())
SQLTestUtils.assert_close([math.hypot(i, 2) for i in range(10)],
df.select(functions.hypot(df.a, 2)).collect())

def test_inverse_trig_functions(self):
from pyspark.sql import functions
Expand All @@ -157,6 +150,23 @@ def test_inverse_trig_functions(self):
for c in cols:
self.assertIn(f"{alias}(a)", repr(f(c)))

def test_reciprocal_trig_functions(self):
# SPARK-36683: Tests for reciprocal trig functions (SEC, CSC and COT)
lst = [0.0, math.pi / 6, math.pi / 4, math.pi / 3, math.pi / 2,
math.pi, 3 * math.pi / 2, 2 * math.pi]

df = self.spark.createDataFrame(lst, types.DoubleType())

def to_reciprocal_trig(func):
return [1.0 / func(i) if func(i) != 0 else math.inf for i in lst]

SQLTestUtils.assert_close(to_reciprocal_trig(math.cos),
df.select(sec(df.value)).collect())
SQLTestUtils.assert_close(to_reciprocal_trig(math.sin),
df.select(csc(df.value)).collect())
SQLTestUtils.assert_close(to_reciprocal_trig(math.tan),
df.select(cot(df.value)).collect())

def test_rand_functions(self):
df = self.df
from pyspark.sql import functions
Expand Down
8 changes: 8 additions & 0 deletions python/pyspark/testing/sqlutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#

import datetime
import math
import os
import shutil
import tempfile
Expand Down Expand Up @@ -243,6 +244,13 @@ def function(self, *functions):
for f in functions:
self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)

@staticmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved get_values and assert_close and combined them as a static method under SQLTestUtils as I don't see any function on the top-level statement.
Please let me know if it should be on the top level rather than a method.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah looks good.

def assert_close(a, b):
c = [j[0] for j in b]
diff = [abs(v - c[k]) < 1e-6 if math.isfinite(v) else v == c[k]
for k, v in enumerate(a)]
return sum(diff) == len(a)


class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ object FunctionRegistry {
expression[Ceil]("ceil"),
expression[Ceil]("ceiling", true),
expression[Cos]("cos"),
expression[Sec]("sec"),
expression[Cosh]("cosh"),
expression[Conv]("conv"),
expression[ToDegrees]("degrees"),
Expand Down Expand Up @@ -392,6 +393,7 @@ object FunctionRegistry {
expression[Signum]("sign", true),
expression[Signum]("signum"),
expression[Sin]("sin"),
expression[Csc]("csc"),
expression[Sinh]("sinh"),
expression[StringToMap]("str_to_map"),
expression[Sqrt]("sqrt"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,29 @@ case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") {
override protected def withNewChildInternal(newChild: Expression): Cos = copy(child = newChild)
}

@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the secant of `expr`, as if computed by `1/java.lang.Math.cos`.
""",
arguments = """
Arguments:
* expr - angle in radians
""",
examples = """
Examples:
> SELECT _FUNC_(0);
1.0
""",
since = "3.3.0",
group = "math_funcs")
case class Sec(child: Expression)
extends UnaryMathExpression((x: Double) => 1 / math.cos(x), "SEC") {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"${ev.value} = 1 / java.lang.Math.cos($c);")
}
override protected def withNewChildInternal(newChild: Expression): Sec = copy(child = newChild)
}

@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the hyperbolic cosine of `expr`, as if computed by
Expand Down Expand Up @@ -655,6 +678,29 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") {
override protected def withNewChildInternal(newChild: Expression): Sin = copy(child = newChild)
}

@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the cosecant of `expr`, as if computed by `1/java.lang.Math.sin`.
""",
arguments = """
Arguments:
* expr - angle in radians
""",
examples = """
Examples:
> SELECT _FUNC_(1);
1.1883951057781212
""",
since = "3.3.0",
group = "math_funcs")
case class Csc(child: Expression)
extends UnaryMathExpression((x: Double) => 1 / math.sin(x), "CSC") {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"${ev.value} = 1 / java.lang.Math.sin($c);")
}
override protected def withNewChildInternal(newChild: Expression): Csc = copy(child = newChild)
}

@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns hyperbolic sine of `expr`, as if computed by `java.lang.Math._FUNC_`.
Expand Down
Loading