diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 2929a00330c6..31238c689329 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -3401,13 +3401,21 @@ setMethod("collect_set",
 #' @details
 #' \code{split_string}: Splits string on regular expression.
-#' Equivalent to \code{split} SQL function.
+#' Equivalent to \code{split} SQL function. Optionally, a
+#' \code{limit} can be specified.
 #'
 #' @rdname column_string_functions
+#' @param limit determines the length of the returned array.
+#'              \itemize{
+#'              \item \code{limit > 0}: length of the array will be at most \code{limit}
+#'              \item \code{limit <= 0}: the returned array can have any length
+#'              }
+#'
 #' @aliases split_string split_string,Column-method
 #' @examples
 #'
 #' \dontrun{
+#' head(select(df, split_string(df$Class, "\\d", 2)))
 #' head(select(df, split_string(df$Sex, "a")))
 #' head(select(df, split_string(df$Class, "\\d")))
 #' # This is equivalent to the following SQL expression
@@ -3415,8 +3423,9 @@ setMethod("collect_set",
 #' @note split_string 2.3.0
 setMethod("split_string",
           signature(x = "Column", pattern = "character"),
-          function(x, pattern) {
-            jc <- callJStatic("org.apache.spark.sql.functions", "split", x@jc, pattern)
+          function(x, pattern, limit = -1) {
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "split", x@jc, pattern, as.integer(limit))
             column(jc)
           })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index f6f1849787a2..a6c3a13302c9 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1242,7 +1242,7 @@ setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array")
 
 #' @rdname column_string_functions
 #' @name NULL
-setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") })
+setGeneric("split_string", function(x, pattern, ...) { standardGeneric("split_string") })
 
 #' @rdname column_string_functions
 #' @name NULL
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index bff6e3512ee2..18c1e4c9663b 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1803,6 +1803,14 @@ test_that("string operators", {
     collect(select(df4, split_string(df4$a, "\\\\")))[1, 1],
     list(list("a.b@c.d 1", "b"))
   )
+  expect_equal(
+    collect(select(df4, split_string(df4$a, "\\.", 2)))[1, 1],
+    list(list("a", "b@c.d 1\\b"))
+  )
+  expect_equal(
+    collect(select(df4, split_string(df4$a, "b", 0)))[1, 1],
+    list(list("a.", "@c.d 1\\", ""))
+  )
 
   l5 <- list(list(a = "abc"))
   df5 <- createDataFrame(l5)
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index e91fc4391425..2bada7bbc4aa 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -952,6 +952,14 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
   }
 
   public UTF8String[] split(UTF8String pattern, int limit) {
+    // Java's String.split drops trailing empty strings when the limit is 0, whereas
+    // other languages do not. To avoid this Java-specific behavior, we fall back to
+    // -1 when the limit is 0.
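+    // For example, "a,b,,".split(",", 0) returns ["a", "b"] in Java, while the same
+    // call with limit -1 keeps the trailing empty strings: ["a", "b", "", ""].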
+    if (limit == 0) {
+      limit = -1;
+    }
     String[] splits = toString().split(pattern.toString(), limit);
     UTF8String[] res = new UTF8String[splits.length];
     for (int i = 0; i < res.length; i++) {
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 42dda3048070..58de990cfc4e 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -394,12 +394,14 @@ public void substringSQL() {
 
   @Test
   public void split() {
-    assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1),
-      new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")}));
-    assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2),
-      new UTF8String[]{fromString("ab"), fromString("def,ghi")}));
-    assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2),
-      new UTF8String[]{fromString("ab"), fromString("def,ghi")}));
+    UTF8String[] negativeAndZeroLimitCase =
+      new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi"), fromString("")};
+    assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), 0),
+      negativeAndZeroLimitCase));
+    assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), -1),
+      negativeAndZeroLimitCase));
+    assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), 2),
+      new UTF8String[]{fromString("ab"), fromString("def,ghi,")}));
   }
 
   @Test
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d58d8d10e5cd..6343374dab0d 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1671,18 +1671,32 @@ def repeat(col, n):
 
 @since(1.5)
 @ignore_unicode_prefix
-def split(str, pattern):
+def split(str, pattern, limit=-1):
     """
-    Splits str around pattern (pattern is a regular expression).
+    Splits str around matches of the given pattern.
 
-    .. note:: pattern is a string represent the regular expression.
+    :param str: a string expression to split
+    :param pattern: a string representing a regular expression. The regex string should be
+        a Java regular expression.
+    :param limit: an integer which controls the number of times `pattern` is applied.
 
-    >>> df = spark.createDataFrame([('ab12cd',)], ['s',])
-    >>> df.select(split(df.s, '[0-9]+').alias('s')).collect()
-    [Row(s=[u'ab', u'cd'])]
+        * ``limit > 0``: The resulting array's length will not be more than `limit`, and the
+          resulting array's last entry will contain all input beyond the last
+          matched pattern.
+        * ``limit <= 0``: `pattern` will be applied as many times as possible, and the resulting
+          array can be of any size.
+
+    .. versionchanged:: 3.0
+       `split` now takes an optional `limit` parameter. If it is not provided, the default
+       limit value is -1.
+
+    >>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',])
+    >>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect()
+    [Row(s=[u'one', u'twoBthreeC'])]
+    >>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect()
+    [Row(s=[u'one', u'two', u'three', u''])]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.split(_to_java_column(str), pattern))
+    return Column(sc._jvm.functions.split(_to_java_column(str), pattern, limit))
 
 
 @ignore_unicode_prefix
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index bf0c35fe6101..4f5ea1e95f83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -157,7 +157,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
   arguments = """
     Arguments:
       * str - a string expression
-      * regexp - a string expression. The pattern string should be a Java regular expression.
+      * regexp - a string expression. The regex string should be a Java regular expression.
 
           Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL
          parser. For example, to match "\abc", a regular expression for `regexp` can be
@@ -229,33 +229,53 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
 
 /**
- * Splits str around pat (pattern is a regular expression).
+ * Splits str around matches of the given regex.
  */
 @ExpressionDescription(
-  usage = "_FUNC_(str, regex) - Splits `str` around occurrences that match `regex`.",
+  usage = "_FUNC_(str, regex, limit) - Splits `str` around occurrences that match `regex`" +
+    " and returns an array with a length of at most `limit`.",
+  arguments = """
+    Arguments:
+      * str - a string expression to split.
+      * regex - a string representing a regular expression. The regex string should be a
+          Java regular expression.
+      * limit - an integer expression which controls the number of times the regex is applied.
+          * limit > 0: The resulting array's length will not be more than `limit`,
+              and the resulting array's last entry will contain all input
+              beyond the last matched regex.
+          * limit <= 0: `regex` will be applied as many times as possible, and
+              the resulting array can be of any size.
+ """, examples = """ Examples: > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]'); ["one","two","three",""] + > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', -1); + ["one","two","three",""] + > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', 2); + ["one","twoBthreeC"] """) -case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ImplicitCastInputTypes { +case class StringSplit(str: Expression, regex: Expression, limit: Expression) + extends TernaryExpression with ImplicitCastInputTypes { - override def left: Expression = str - override def right: Expression = pattern override def dataType: DataType = ArrayType(StringType) - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + override def children: Seq[Expression] = str :: regex :: limit :: Nil - override def nullSafeEval(string: Any, regex: Any): Any = { - val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) + def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1)); + + override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = { + val strings = string.asInstanceOf[UTF8String].split( + regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int]) new GenericArrayData(strings.asInstanceOf[Array[Any]]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName - nullSafeCodeGen(ctx, ev, (str, pattern) => + nullSafeCodeGen(ctx, ev, (str, regex, limit) => { // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"""${ev.value} = new $arrayClass($str.split($pattern, -1));""") + s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin + }) } override def prettyName: String = "split" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index d532dc4f7719..06fb73ad8392 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -225,11 +225,18 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row3 = create_row("aa2bb3cc", null) checkEvaluation( - StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) + StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+"), -1), Seq("aa", "bb", "cc"), row1) checkEvaluation( - StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) - checkEvaluation(StringSplit(s1, s2), null, row2) - checkEvaluation(StringSplit(s1, s2), null, row3) + StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+"), 2), Seq("aa", "bb3cc"), row1) + // limit = 0 should behave just like limit = -1 + checkEvaluation( + StringSplit(Literal("aacbbcddc"), Literal("c"), 0), Seq("aa", "bb", "dd", ""), row1) + checkEvaluation( + StringSplit(Literal("aacbbcddc"), Literal("c"), -1), Seq("aa", "bb", "dd", ""), row1) + checkEvaluation( + StringSplit(s1, s2, -1), Seq("aa", "bb", "cc"), row1) + checkEvaluation(StringSplit(s1, s2, -1), null, row2) + checkEvaluation(StringSplit(s1, s2, -1), null, row3) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c9331883c479..575f00b8807c 100644 --- 
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2546,15 +2546,39 @@ object functions {
   def soundex(e: Column): Column = withExpr { SoundEx(e.expr) }
 
   /**
-   * Splits str around pattern (pattern is a regular expression).
+   * Splits str around matches of the given regex.
    *
-   * @note Pattern is a string representation of the regular expression.
+   * @param str a string expression to split
+   * @param regex a string representing a regular expression. The regex string should be
+   *              a Java regular expression.
    *
    * @group string_funcs
    * @since 1.5.0
    */
-  def split(str: Column, pattern: String): Column = withExpr {
-    StringSplit(str.expr, lit(pattern).expr)
+  def split(str: Column, regex: String): Column = withExpr {
+    StringSplit(str.expr, Literal(regex), Literal(-1))
+  }
+
+  /**
+   * Splits str around matches of the given regex.
+   *
+   * @param str a string expression to split
+   * @param regex a string representing a regular expression. The regex string should be
+   *              a Java regular expression.
+   * @param limit an integer expression which controls the number of times the regex is applied.
+   *
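For illustration, here is a minimal end-to-end sketch of the new `limit` parameter as exposed through the Scala API (the object name, local-mode session, and app name are assumptions for the example, not part of the patch; the expected results mirror the SQL examples above):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.split

object SplitLimitExample extends App {
  // Local session purely for demonstration purposes.
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("split-limit-example")
    .getOrCreate()
  import spark.implicits._

  val df = Seq("oneAtwoBthreeC").toDF("s")

  // limit > 0: at most `limit` entries; the last entry keeps the unsplit remainder.
  df.select(split($"s", "[ABC]", 2)).show(false)   // [one, twoBthreeC]

  // limit <= 0: the regex is applied as many times as possible and trailing empty
  // strings are kept; limit = 0 falls back to -1 inside UTF8String.split.
  df.select(split($"s", "[ABC]", -1)).show(false)  // [one, two, three, ]

  spark.stop()
}
```

The Python `split(str, pattern, limit=-1)` and R `split_string(x, pattern, limit = -1)` wrappers in this patch forward to the same JVM function, so they share these semantics.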