Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -229,33 +229,59 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress


/**
* Splits str around pat (pattern is a regular expression).
* Splits str around pattern (pattern is a regular expression).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pattern? regex? we should use a consisntent word.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

going to switch to regex, makes more sense given that with the use of pattern we always have to define it as a regex

*/
@ExpressionDescription(
usage = "_FUNC_(str, regex) - Splits `str` around occurrences that match `regex`.",
usage = "_FUNC_(str, regex, limit) - Splits `str` around occurrences that match `regex`" +
" and returns an array of at most `limit`",
arguments = """
Arguments:
* str - a string expression to split.
* pattern - a string representing a regular expression. The pattern string should be a
Java regular expression.
* limit - an integer expression which controls the number of times the pattern is applied.

limit > 0:
The resulting array's length will not be more than `limit`, and the resulting array's
last entry will contain all input beyond the last matched pattern.

limit < 0:
`pattern` will be applied as many times as possible, and the resulting
array can be of any size.

limit = 0:
`pattern` will be applied as many times as possible, the resulting array can
be of any size, and trailing empty strings will be discarded.
""",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this formatting?;


 function_desc | Extended Usage:
    Arguments:
      * str - a string expression to split.
      * pattern - a string representing a regular expression. The pattern string should be a
        Java regular expression.
      * limit - an integer expression which controls the number of times the pattern is applied.

        limit > 0: The resulting array's length will not be more than `limit`, and the resulting array's
                   last entry will contain all input beyond the last matched pattern.
        limit < 0: `pattern` will be applied as many times as possible, and the resulting
                   array can be of any size.
        limit = 0: `pattern` will be applied as many times as possible, the resulting array can
                   be of any size, and trailing empty strings will be discarded.
  
    Examples:
      > SELECT split('oneAtwoBthreeC', '[ABC]');
       ["one","two","three",""]
      > SELECT split('oneAtwoBthreeC', '[ABC]', 0);
       ["one","two","three"]
      > SELECT split('oneAtwoBthreeC', '[ABC]', 2);
       ["one","twoBthreeC"]

examples = """
Examples:
> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]');
["one","two","three",""]
| > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drop |

["one","two","three"]
| > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', 2);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

["one","twoBthreeC"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the netative case?

""")
case class StringSplit(str: Expression, pattern: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
case class StringSplit(str: Expression, pattern: Expression, limit: Expression)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to support 2 arguments. Please add a constructor def this(str: Expression, pattern: Expression).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For test coverage, better to add tests in string-functions.sql for the two cases: two arguments and three arguments.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu which tests use string-functions.sql? would like to add tests here but not sure how to explicitly kick off the test as there are no *Suites which use this file it seems.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^ ignore this! found it @maropu

extends TernaryExpression with ImplicitCastInputTypes {

override def left: Expression = str
override def right: Expression = pattern
override def dataType: DataType = ArrayType(StringType)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def children: Seq[Expression] = str :: pattern :: limit :: Nil

def this(exp: Expression, pattern: Expression) = this(exp, pattern, Literal(-1));

override def nullSafeEval(string: Any, regex: Any): Any = {
val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1)
override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we still need to do some check on limit. According to Presto document, limit must be a positive number. -1 is only used when no limit parameter is given (default value).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@viirya the underlying implementation of this method is Java.lang.String, correct? This method does allow non-positive values for limit, not sure what Presto is using.

val strings = string.asInstanceOf[UTF8String].split(
regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int])
new GenericArrayData(strings.asInstanceOf[Array[Any]])
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, (str, pattern) =>
nullSafeCodeGen(ctx, ev, (str, pattern, limit) =>
// Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
s"""${ev.value} = new $arrayClass($str.split($pattern, -1));""")
s"""${ev.value} = new $arrayClass($str.split($pattern, $limit));""")
}

override def prettyName: String = "split"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,17 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val row3 = create_row("aa2bb3cc", null)

checkEvaluation(
StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1)
StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+"), -1), Seq("aa", "bb", "cc"), row1)
checkEvaluation(
StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
checkEvaluation(StringSplit(s1, s2), null, row2)
checkEvaluation(StringSplit(s1, s2), null, row3)
StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+"), 2), Seq("aa", "bb3cc"), row1)
checkEvaluation(
StringSplit(Literal("aacbbcddc"), Literal("c"), 0), Seq("aa", "bb", "dd"), row1)
checkEvaluation(
StringSplit(Literal("aacbbcddc"), Literal("c"), -1), Seq("aa", "bb", "dd", ""), row1)
checkEvaluation(
StringSplit(s1, s2, -1), Seq("aa", "bb", "cc"), row1)
checkEvaluation(StringSplit(s1, s2, -1), null, row2)
checkEvaluation(StringSplit(s1, s2, -1), null, row3)
}

}
22 changes: 21 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2554,7 +2554,27 @@ object functions {
* @since 1.5.0
*/
def split(str: Column, pattern: String): Column = withExpr {
StringSplit(str.expr, lit(pattern).expr)
StringSplit(str.expr, Literal(pattern), Literal(-1))
}

/**
* Splits str around pattern (pattern is a regular expression).
*
* The limit parameter controls the number of times the pattern is applied and therefore
* affects the length of the resulting array. If the limit n is greater than zero then the
* pattern will be applied at most n - 1 times, the array's length will be no greater than
* n, and the array's last entry will contain all input beyond the last matched delimiter.
* If n is non-positive then the pattern will be applied as many times as possible and the
* array can have any length. If n is zero then the pattern will be applied as many times as
* possible, the array can have any length, and trailing empty strings will be discarded.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you copy SQL's doc here? You could describe them via @param here as well.

*
* @note Pattern is a string representation of the regular expression.
*
* @group string_funcs
* @since 2.4.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto for 3.0.0

*/
def split(str: Column, pattern: String, limit: Int): Column = withExpr {
StringSplit(str.expr, Literal(pattern), Literal(limit))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,7 @@ FROM (
encode(string(id + 3), 'utf-8') col4
FROM range(10)
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The result is apparently wrong. Maybe we need ; here.


-- split function
select split('aa1cc2ee', '[1-9]+', 2);
select split('aa1cc2ee', '[1-9]+');
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 15
-- Number of queries: 16


-- !query 0
Expand Down Expand Up @@ -155,9 +155,32 @@ FROM (
encode(string(id + 3), 'utf-8') col4
FROM range(10)
)

select split('aa1cc2ee', '[1-9]+', 2)
-- !query 14 schema
struct<plan:string>
struct<>
-- !query 14 output
== Physical Plan ==
*Project [concat(cast(id#xL as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
+- *Range (0, 10, step=1, splits=2)
org.apache.spark.sql.catalyst.parser.ParseException

mismatched input 'select' expecting <EOF>(line 10, pos 0)

== SQL ==
EXPLAIN SELECT (col1 || (col3 || col4)) col
FROM (
SELECT
string(id) col1,
encode(string(id + 2), 'utf-8') col3,
encode(string(id + 3), 'utf-8') col4
FROM range(10)
)

select split('aa1cc2ee', '[1-9]+', 2)
^^^


-- !query 15
select split('aa1cc2ee', '[1-9]+')
-- !query 15 schema
struct<split(aa1cc2ee, [1-9]+, -1):array<string>>
-- !query 15 output
["aa","cc","ee"]
Original file line number Diff line number Diff line change
Expand Up @@ -329,16 +329,52 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
Row(" "))
}

test("string split function") {
val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b")
test("string split function with no limit") {
val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b")

checkAnswer(
df.select(split($"a", "[1-9]+")),
Row(Seq("aa", "bb", "cc")))
Row(Seq("aa", "bb", "cc", "")))

checkAnswer(
df.selectExpr("split(a, '[1-9]+')"),
Row(Seq("aa", "bb", "cc", "")))
}

test("string split function with limit explicitly set to 0") {
val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b")

checkAnswer(
df.select(split($"a", "[1-9]+", 0)),
Row(Seq("aa", "bb", "cc")))

checkAnswer(
df.selectExpr("split(a, '[1-9]+', 0)"),
Row(Seq("aa", "bb", "cc")))
}

test("string split function with positive limit") {
val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b")

checkAnswer(
df.select(split($"a", "[1-9]+", 2)),
Row(Seq("aa", "bb3cc4")))

checkAnswer(
df.selectExpr("split(a, '[1-9]+', 2)"),
Row(Seq("aa", "bb3cc4")))
}

test("string split function with negative limit") {
val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b")

checkAnswer(
df.select(split($"a", "[1-9]+", -2)),
Row(Seq("aa", "bb", "cc", "")))

checkAnswer(
df.selectExpr("split(a, '[1-9]+', -2)"),
Row(Seq("aa", "bb", "cc", "")))
}

test("string / binary length function") {
Expand Down