Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1df379b
Implement scala/python octet_length
yoda-mon Sep 10, 2021
6df4668
Implement R octet_length
yoda-mon Sep 10, 2021
eb2456b
Implement bit-length functions
yoda-mon Sep 13, 2021
3e149fa
Rearange the order of octet_lengh
yoda-mon Sep 13, 2021
ff2f031
Disable scalastyle to use non-ascii charactor on tests
yoda-mon Sep 13, 2021
5b9c98b
Insert newline to python docstrings to keep pylint rules
yoda-mon Sep 13, 2021
848f224
add space after comma to pass lint-R
yoda-mon Sep 14, 2021
4edc918
Implement scala/python octet_length
yoda-mon Sep 10, 2021
5370a51
Implement R octet_length
yoda-mon Sep 10, 2021
72c9398
Implement bit-length functions
yoda-mon Sep 13, 2021
761f8b6
Rearange the order of octet_lengh
yoda-mon Sep 13, 2021
f7ace95
Disable scalastyle to use non-ascii charactor on tests
yoda-mon Sep 13, 2021
2ed29f5
Insert newline to python docstrings to keep pylint rules
yoda-mon Sep 13, 2021
0858001
add space after comma to pass lint-R
yoda-mon Sep 14, 2021
c0ffdce
Merge branch 'add-bit-octet-length' of https://github.com/yoda-mon/sp…
yoda-mon Sep 14, 2021
327d7d5
Delete unnecessary line in the scala test.
yoda-mon Sep 14, 2021
387a92f
Add a short description of the test
yoda-mon Sep 14, 2021
70a1217
Add a short description to docstrings
yoda-mon Sep 14, 2021
3a07337
Move import to the top of the file
yoda-mon Sep 14, 2021
545fbe2
differentiate octet_length and bit_length test name
yoda-mon Sep 14, 2021
afa1700
Formatting decstrings
yoda-mon Sep 15, 2021
127da47
differentiate test comments
yoda-mon Sep 15, 2021
41656f4
Add bit/octet length to the api document
yoda-mon Sep 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ exportMethods("%<=>%",
"base64",
"between",
"bin",
"bit_length",
"bitwise_not",
"bitwiseNOT",
"bround",
Expand Down Expand Up @@ -364,6 +365,7 @@ exportMethods("%<=>%",
"not",
"nth_value",
"ntile",
"octet_length",
"otherwise",
"over",
"overlay",
Expand Down
26 changes: 26 additions & 0 deletions R/pkg/R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,19 @@ setMethod("bin",
column(jc)
})

#' @details
#' \code{bit_length}: Calculates the bit length for the specified string column.
#'
#' @rdname column_string_functions
#' @aliases bit_length bit_length,Column-method
#' @note bit_length since 3.3.0
setMethod("bit_length",
signature(x = "Column"),
function(x) {
# Delegate to the JVM-side org.apache.spark.sql.functions.bit_length and wrap the result.
jc <- callJStatic("org.apache.spark.sql.functions", "bit_length", x@jc)
column(jc)
})

#' @details
#' \code{bitwise_not}: Computes bitwise NOT.
#'
Expand Down Expand Up @@ -1569,6 +1582,19 @@ setMethod("negate",
column(jc)
})

#' @details
#' \code{octet_length}: Calculates the byte length for the specified string column.
#'
#' @rdname column_string_functions
#' @aliases octet_length octet_length,Column-method
#' @note octet_length since 3.3.0
setMethod("octet_length",
signature(x = "Column"),
function(x) {
# Delegate to the JVM-side org.apache.spark.sql.functions.octet_length and wrap the result.
jc <- callJStatic("org.apache.spark.sql.functions", "octet_length", x@jc)
column(jc)
})

#' @details
#' \code{overlay}: Overlay the specified portion of \code{x} with \code{replace},
#' starting from byte position \code{pos} of \code{src} and proceeding for
Expand Down
8 changes: 8 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,10 @@ setGeneric("base64", function(x) { standardGeneric("base64") })
#' @name NULL
setGeneric("bin", function(x) { standardGeneric("bin") })

# NOTE(review): the Column method registered for bit_length takes only `x`;
# confirm that the extra `...` in this generic's signature is intentional.
#' @rdname column_string_functions
#' @name NULL
setGeneric("bit_length", function(x, ...) { standardGeneric("bit_length") })

#' @rdname column_nonaggregate_functions
#' @name NULL
setGeneric("bitwise_not", function(x) { standardGeneric("bitwise_not") })
Expand Down Expand Up @@ -1230,6 +1234,10 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") })
#' @name NULL
setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })

# NOTE(review): the Column method registered for octet_length takes only `x`;
# confirm that the extra `...` in this generic's signature is intentional.
#' @rdname column_string_functions
#' @name NULL
setGeneric("octet_length", function(x, ...) { standardGeneric("octet_length") })

#' @rdname column_string_functions
#' @name NULL
setGeneric("overlay", function(x, replace, pos, ...) { standardGeneric("overlay") })
Expand Down
11 changes: 11 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1988,6 +1988,17 @@ test_that("string operators", {
collect(select(df5, repeat_string(df5$a, -1)))[1, 1],
""
)

l6 <- list(list("cat"), list("\ud83d\udc08"))
df6 <- createDataFrame(l6)
expect_equal(
collect(select(df6, octet_length(df6$"_1")))[, 1],
c(3, 4)
)
expect_equal(
collect(select(df6, bit_length(df6$"_1")))[, 1],
c(24, 32)
)
})

test_that("date functions on a DataFrame", {
Expand Down
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ Functions
avg
base64
bin
bit_length
bitwise_not
bitwiseNOT
broadcast
Expand Down Expand Up @@ -483,6 +484,7 @@ Functions
next_day
nth_value
ntile
octet_length
overlay
pandas_udf
percent_rank
Expand Down
52 changes: 52 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3098,6 +3098,58 @@ def length(col):
return Column(sc._jvm.functions.length(_to_java_column(col)))


def octet_length(col):
    """
    Calculates the byte length for the specified string column.

    .. versionadded:: 3.3.0

    Parameters
    ----------
    col : :class:`~pyspark.sql.Column` or str
        Source column or strings

    Returns
    -------
    :class:`~pyspark.sql.Column`
        Byte length of the col

    Examples
    --------
    >>> from pyspark.sql.functions import octet_length
    >>> df = spark.createDataFrame([('cat',), ('\U0001F408',)], ['cat'])
    >>> df.select(octet_length('cat')).collect()
    [Row(octet_length(cat)=3), Row(octet_length(cat)=4)]
    """
    return _invoke_function_over_column("octet_length", col)


def bit_length(col):
    """
    Calculates the bit length for the specified string column.

    .. versionadded:: 3.3.0

    Parameters
    ----------
    col : :class:`~pyspark.sql.Column` or str
        Source column or strings

    Returns
    -------
    :class:`~pyspark.sql.Column`
        Bit length of the col

    Examples
    --------
    >>> from pyspark.sql.functions import bit_length
    >>> df = spark.createDataFrame([('cat',), ('\U0001F408',)], ['cat'])
    >>> df.select(bit_length('cat')).collect()
    [Row(bit_length(cat)=24), Row(bit_length(cat)=32)]
    """
    return _invoke_function_over_column("bit_length", col)


def translate(srcCol, matching, replace):
"""A function translate any character in the `srcCol` by a character in `matching`.
The characters in `replace` is corresponding to the characters in `matching`.
Expand Down
2 changes: 2 additions & 0 deletions python/pyspark/sql/functions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ def bin(col: ColumnOrName) -> Column: ...
def hex(col: ColumnOrName) -> Column: ...
def unhex(col: ColumnOrName) -> Column: ...
def length(col: ColumnOrName) -> Column: ...
# Byte length and bit length of a string column; both accept a Column or a column name.
def octet_length(col: ColumnOrName) -> Column: ...
def bit_length(col: ColumnOrName) -> Column: ...
def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column: ...
def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: ...
def create_map(*cols: ColumnOrName) -> Column: ...
Expand Down
14 changes: 13 additions & 1 deletion python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pyspark.sql import Row, Window
from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, \
lit, assert_true, sum_distinct, sumDistinct, shiftleft, shiftLeft, shiftRight, \
shiftright, shiftrightunsigned, shiftRightUnsigned
shiftright, shiftrightunsigned, shiftRightUnsigned, octet_length, bit_length
from pyspark.testing.sqlutils import ReusedSQLTestCase


Expand Down Expand Up @@ -197,6 +197,18 @@ def test_string_functions(self):
df.select(getattr(functions, name)("name")).first()[0],
df.select(getattr(functions, name)(col("name"))).first()[0])

def test_octet_length_function(self):
    # SPARK-36751: add octet length api for python
    # One ASCII string and one string holding a single non-BMP character (U+1F408).
    rows = [('cat',), ('\U0001F408',)]
    source = self.spark.createDataFrame(rows, ['cat'])
    expected = [Row(3), Row(4)]
    self.assertEqual(expected, source.select(octet_length('cat')).collect())

def test_bit_length_function(self):
    # SPARK-36751: add bit length api for python
    # One ASCII string and one string holding a single non-BMP character (U+1F408).
    rows = [('cat',), ('\U0001F408',)]
    source = self.spark.createDataFrame(rows, ['cat'])
    expected = [Row(24), Row(32)]
    self.assertEqual(expected, source.select(bit_length('cat')).collect())

def test_array_contains_function(self):
from pyspark.sql.functions import array_contains

Expand Down
16 changes: 16 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2541,6 +2541,14 @@ object functions {
*/
def base64(e: Column): Column = withExpr { Base64(e.expr) }

/**
 * Calculates the bit length for the specified string column.
 * Binary columns are accepted as well; each byte counts as 8 bits.
 *
 * @group string_funcs
 * @since 3.3.0
 */
def bit_length(e: Column): Column = withExpr { BitLength(e.expr) }

/**
* Concatenates multiple input string columns together into a single string column,
* using the given separator.
Expand Down Expand Up @@ -2706,6 +2714,14 @@ object functions {
StringTrimLeft(e.expr, Literal(trimString))
}

/**
 * Calculates the byte length for the specified string column.
 * Binary columns are accepted as well. The result counts bytes rather than characters,
 * so a single character may contribute more than one (e.g. U+1F408 yields 4).
 *
 * @group string_funcs
 * @since 3.3.0
 */
def octet_length(e: Column): Column = withExpr { OctetLength(e.expr) }

/**
* Extract a specific group matched by a Java regex, from the specified string column.
* If the regex did not match, or the specified group did not match, an empty string is returned.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,58 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
)
}

test("SPARK-36751: add octet length api for scala") {
  // scalastyle:off
  // non ascii characters are not allowed in the code, so we disable the scalastyle here.
  val input = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015, "\ud83d\udc08"))
    .toDF("a", "b", "c", "d", "e", "f")

  // Column API: string and binary input
  val viaColumnApi = input.select(octet_length($"a"), octet_length($"b"))
  checkAnswer(viaColumnApi, Row(3, 4))

  // SQL expression: string and binary input
  val viaSqlExpr = input.selectExpr("octet_length(a)", "octet_length(b)")
  checkAnswer(viaSqlExpr, Row(3, 4))

  // SQL expression: integer, float and double input
  val numericLengths =
    input.selectExpr("octet_length(c)", "octet_length(d)", "octet_length(e)")
  checkAnswer(numericLengths, Row(3, 3, 5))

  // SQL expression: multi-byte character input
  val multiByteLength = input.selectExpr("octet_length(f)")
  checkAnswer(multiByteLength, Row(4))
  // scalastyle:on
}

test("SPARK-36751: add bit length api for scala") {
  // scalastyle:off
  // non ascii characters are not allowed in the code, so we disable the scalastyle here.
  val input = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015, "\ud83d\udc08"))
    .toDF("a", "b", "c", "d", "e", "f")

  // Column API: string and binary input
  val viaColumnApi = input.select(bit_length($"a"), bit_length($"b"))
  checkAnswer(viaColumnApi, Row(24, 32))

  // SQL expression: string and binary input
  val viaSqlExpr = input.selectExpr("bit_length(a)", "bit_length(b)")
  checkAnswer(viaSqlExpr, Row(24, 32))

  // SQL expression: integer, float and double input
  val numericLengths =
    input.selectExpr("bit_length(c)", "bit_length(d)", "bit_length(e)")
  checkAnswer(numericLengths, Row(24, 24, 40))

  // SQL expression: multi-byte character input
  val multiByteLength = input.selectExpr("bit_length(f)")
  checkAnswer(multiByteLength, Row(32))
  // scalastyle:on
}

test("initcap function") {
val df = Seq(("ab", "a B", "sParK")).toDF("x", "y", "z")
checkAnswer(
Expand Down