From c02b10d21d3d1ebaddd93c581444412e67dc7ef0 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Thu, 10 Nov 2016 23:29:16 -0800
Subject: [PATCH 1/3] fix

---
 .../expressions/stringExpressions.scala            |  4 +
 .../spark/sql/StringFunctionsSuite.scala           | 92 ++++++++++++++++++-
 2 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 5f533fecf8d0..2bece63506e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -44,6 +44,10 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 @ExpressionDescription(
   usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of `str1`, `str2`, ..., `strN`.",
   extended = """
+    Arguments:
+      str - An expression that returns a value of a character string. If any argument is null, the
+        result is the null value.
+
     Examples:
       > SELECT _FUNC_('Spark','SQL');
        SparkSQL

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index bcc235104995..ada36dbbbec5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -17,14 +17,18 @@
 package org.apache.spark.sql
 
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
 
 class StringFunctionsSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
-  test("string concat") {
+  test("string concat - basic") {
     val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
 
     checkAnswer(
@@ -36,6 +40,27 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
       Row("ab", null))
   }
 
+  test("string concat - all compatible types") {
+    val allTypeData = AllTypeData(spark)
+    val df = allTypeData.dataFrame
+    checkAnswer(
+      df.select(concat(allTypeData.getStringCompatibleColumns: _*)),
+      Row("11111.011.011.010001970-01-011970-01-01 00:00:00aatrue") ::
+        Row("22222.022.022.020001970-02-021970-01-01 00:00:05bbbbfalse") :: Nil)
+  }
+
+  test("string concat - unsupported types") {
+    val allTypeData = AllTypeData(spark)
+    val df = allTypeData.dataFrame
+
+    Seq(allTypeData.mapCol, allTypeData.arrayCol, allTypeData.structCol).foreach { col =>
+      val e = intercept[AnalysisException] {
+        df.select(concat(allTypeData.stringCol, col))
+      }.getMessage
+      assert(e.contains("argument 2 requires string type"))
+    }
+  }
+
   test("string concat_ws") {
     val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
 
@@ -452,3 +477,68 @@
     }
   }
 }
+
+case class AllTypeData(spark: SparkSession) {
+  private val intSeq = Seq(1, 2)
+  private val doubleSeq = Seq(1.01d, 2.02d)
+  private val stringSeq = Seq("a", "bb")
+  private val booleanSeq = Seq(true, false)
+  private val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf)
+  private val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05")
+    .map(Timestamp.valueOf)
+  private val arraySeq = Seq(Seq(1), Seq(2))
+  private val mapSeq = Seq(Map("a" -> "1", "b" -> "2"), Map("d" -> "3", "e" -> "4"))
+  private val structSeq = Seq(Row("d"), Row("c"))
+
+  private val allTypeSchema = StructType(Seq(
+    StructField("byteCol", ByteType),
+    StructField("shortCol", ShortType),
+    StructField("intCol", IntegerType),
+    StructField("longCol", LongType),
+    StructField("floatCol", FloatType),
+    StructField("doubleCol", DoubleType),
+    StructField("decimalCol", DecimalType(10, 5)),
+    StructField("dateCol", DateType),
+    StructField("timestampCol", TimestampType),
+    StructField("stringCol", StringType),
+    StructField("binaryCol", BinaryType),
+    StructField("booleanCol", BooleanType),
+    StructField("arrayCol", ArrayType(IntegerType, containsNull = true)),
+    StructField("mapCol", MapType(StringType, StringType)),
+    StructField("structCol", new StructType().add("a", StringType))
+  ))
+
+  private val rowRDD: RDD[Row] = spark.sparkContext.parallelize(intSeq.indices.map { i =>
+    Row(intSeq(i).toByte, intSeq(i).toShort, intSeq(i), intSeq(i).toLong,
+      doubleSeq(i).toFloat, doubleSeq(i), Decimal(doubleSeq(i)),
+      dateSeq(i), timestampSeq(i),
+      stringSeq(i), stringSeq(i).getBytes,
+      booleanSeq(i),
+      arraySeq(i),
+      mapSeq(i),
+      structSeq(i))
+  })
+
+  val dataFrame: DataFrame = spark.createDataFrame(rowRDD, allTypeSchema)
+
+  val byteCol = Column("byteCol")
+  val shortCol = Column("shortCol")
+  val intCol = Column("intCol")
+  val longCol = Column("longCol")
+  val floatCol = Column("floatCol")
+  val doubleCol = Column("doubleCol")
+  val decimalCol = Column("decimalCol")
+  val dateCol = Column("dateCol")
+  val timestampCol = Column("timestampCol")
+  val stringCol = Column("stringCol")
+  val binaryCol = Column("binaryCol")
+  val booleanCol = Column("booleanCol")
+  val arrayCol = Column("arrayCol")
+  val mapCol = Column("mapCol")
+  val structCol = Column("structCol")
+
+  def getStringCompatibleColumns: Seq[Column] = {
+    Seq(byteCol, shortCol, intCol, longCol, floatCol, doubleCol, decimalCol,
+      dateCol, timestampCol, stringCol, binaryCol, booleanCol)
+  }
+}
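
The Arguments text added above can be checked end to end from a spark-shell. The session below is illustrative only, not part of the patch; it assumes a SparkSession named `spark` and shows the two behaviors the new doc calls out: implicit casting of non-string arguments and null propagation.

    // Non-string arguments are implicitly cast to string before concatenation.
    spark.sql("SELECT concat('id-', 42, true)").show()                 // id-42true
    // A single null argument makes the whole result null.
    spark.sql("SELECT concat('a', CAST(NULL AS STRING), 'b')").show()  // NULL
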
-> "4")) + private val structSeq = Seq(Row("d"), Row("c")) + + private val allTypeSchema = StructType(Seq( + StructField("byteCol", ByteType), + StructField("shortCol", ShortType), + StructField("intCol", IntegerType), + StructField("longCol", LongType), + StructField("floatCol", FloatType), + StructField("doubleCol", DoubleType), + StructField("decimalCol", DecimalType(10, 5)), + StructField("dateCol", DateType), + StructField("timestampCol", TimestampType), + StructField("stringCol", StringType), + StructField("binaryCol", BinaryType), + StructField("booleanCol", BooleanType), + StructField("arrayCol", ArrayType(IntegerType, containsNull = true)), + StructField("mapCol", MapType(StringType, StringType)), + StructField("structCol", new StructType().add("a", StringType)) + )) + + private val rowRDD: RDD[Row] = spark.sparkContext.parallelize(intSeq.indices.map { i => + Row(intSeq(i).toByte, intSeq(i).toShort, intSeq(i), intSeq(i).toLong, + doubleSeq(i).toFloat, doubleSeq(i), Decimal(doubleSeq(i)), + dateSeq(i), timestampSeq(i), + stringSeq(i), stringSeq(i).getBytes, + booleanSeq(i), + arraySeq(i), + mapSeq(i), + structSeq(i)) + }) + + val dataFrame: DataFrame = spark.createDataFrame(rowRDD, allTypeSchema) + + val byteCol = Column("byteCol") + val shortCol = Column("shortCol") + val intCol = Column("intCol") + val longCol = Column("longCol") + val floatCol = Column("floatCol") + val doubleCol = Column("doubleCol") + val decimalCol = Column("decimalCol") + val dateCol = Column("dateCol") + val timestampCol = Column("timestampCol") + val stringCol = Column("stringCol") + val binaryCol = Column("binaryCol") + val booleanCol = Column("booleanCol") + val arrayCol = Column("arrayCol") + val mapCol = Column("mapCol") + val structCol = Column("structCol") + + def getStringCompatibleColumns: Seq[Column] = { + Seq(byteCol, shortCol, intCol, longCol, floatCol, doubleCol, decimalCol, + dateCol, timestampCol, stringCol, binaryCol, booleanCol) + } +} From e782a5e316a29359b2218a6751bd61372149e857 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 11 Nov 2016 09:30:26 -0800 Subject: [PATCH 2/3] fix --- .../scala/org/apache/spark/sql/StringFunctionsSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ada36dbbbec5..d508f109660b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -41,7 +41,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("string concat - all compatible types") { - val allTypeData = AllTypeData(spark) + val allTypeData = AllTypeTestData(spark) val df = allTypeData.dataFrame checkAnswer( df.select(concat(allTypeData.getStringCompatibleColumns: _*)), @@ -50,7 +50,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("string concat - unsupported types") { - val allTypeData = AllTypeData(spark) + val allTypeData = AllTypeTestData(spark) val df = allTypeData.dataFrame Seq(allTypeData.mapCol, allTypeData.arrayCol, allTypeData.structCol).foreach { col => @@ -478,7 +478,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } } -case class AllTypeData (spark: SparkSession) { +case class AllTypeTestData (spark: SparkSession) { private val intSeq = Seq(1, 2) private val doubleSeq = Seq(1.01d, 2.02d) private val stringSeq = Seq("a", 
"bb") From 1ba3d6716781c1aca606ad416c9b45bef185571f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 13 Nov 2016 01:18:29 -0800 Subject: [PATCH 3/3] concat_ws --- .../expressions/stringExpressions.scala | 22 ++++++- .../spark/sql/StringFunctionsSuite.scala | 57 +++++++++++++++++-- 2 files changed, 72 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 2bece63506e5..98d12c2ef446 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -45,7 +46,9 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of `str1`, `str2`, ..., `strN`.", extended = """ Arguments: - str - An expression that returns a value of a character string. If any argument is null, the + str - The strings to be concatenated. + + The arguments are expressions that return a value of a character string. If any argument is null, the result is the null value. Examples: @@ -92,6 +95,15 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas @ExpressionDescription( usage = "_FUNC_(sep, [str | array(str)]+) - Returns the concatenation of the strings separated by `sep`.", extended = """ + Arguments: + sep - The separator for the rest of the arguments. + str | array(str) - The strings to be concatenated. + + The arguments can be expressions that return a value of a character string. The arguments from + the second argument can also be expressions that return array. Minimum number of + arguments is 3. The function ignores null values and returns an empty string if all values + are null. It returns null only if the separator is null. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index d508f109660b..d43ed2e0fc90 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -53,24 +53,65 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
     val allTypeData = AllTypeTestData(spark)
     val df = allTypeData.dataFrame
 
-    Seq(allTypeData.mapCol, allTypeData.arrayCol, allTypeData.structCol).foreach { col =>
+    Seq(allTypeData.mapCol, allTypeData.arrayIntCol, allTypeData.structCol).foreach { col =>
       val e = intercept[AnalysisException] {
         df.select(concat(allTypeData.stringCol, col))
       }.getMessage
       assert(e.contains("argument 2 requires string type"))
     }
   }
 
-  test("string concat_ws") {
+  test("string concat_ws - basic") {
     val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
 
     checkAnswer(
       df.select(concat_ws("||", $"a", $"b", $"c")),
       Row("a||b"))
 
     checkAnswer(
       df.selectExpr("concat_ws('||', a, b, c)"),
       Row("a||b"))
+
+    checkAnswer(
+      df.selectExpr("concat_ws(null, a, b, c)"),
+      Row(null))
+
+    checkAnswer(
+      df.selectExpr("concat_ws(a, b, b)"),
+      Row("bab"))
+
+    val df1 = Seq[(String, String)]((null, null)).toDF("a", "b")
+
+    checkAnswer(
+      df1.selectExpr("concat_ws('||', a, b)"),
+      Row(""))
+
+    val e = intercept[AnalysisException] {
+      df1.selectExpr("concat_ws('||', b)")
+    }.getMessage
+    assert(e.contains("requires at least three arguments"))
+  }
+
+  test("string concat_ws - all compatible types") {
+    val allTypeData = AllTypeTestData(spark)
+    val df = allTypeData.dataFrame
+    checkAnswer(
+      df.select(concat_ws("_",
+        allTypeData.getStringCompatibleColumns :+ allTypeData.arrayStringCol: _*)),
+      Row("1_1_1_1_1.01_1.01_1.01000_1970-01-01_1970-01-01 00:00:00_a_a_true_a_b") ::
+        Row("2_2_2_2_2.02_2.02_2.02000_1970-02-02_1970-01-01 00:00:05_bb_bb_false_c_d") :: Nil)
+  }
+
+  test("string concat_ws - unsupported types") {
+    val allTypeData = AllTypeTestData(spark)
+    val df = allTypeData.dataFrame
+
+    Seq(allTypeData.mapCol, allTypeData.arrayIntCol, allTypeData.structCol).foreach { col =>
+      val e = intercept[AnalysisException] {
+        df.select(concat_ws("_", col, col))
+      }.getMessage
+      assert(e.contains("argument 2 requires (array or string) type"))
+    }
   }
 
   test("string elt") {
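
The expanded basic test above pins down concat_ws's null contract, which is the opposite of concat's documented one. An illustrative contrast (not part of the patch; assumes `spark` in scope):

    // concat_ws skips null string arguments...
    spark.sql("SELECT concat_ws('-', 'a', NULL, 'b')").show()  // a-b
    // ...while concat propagates them.
    spark.sql("SELECT concat('a', NULL, 'b')").show()          // NULL
    // A null separator, however, nulls out the whole concat_ws result.
    spark.sql("SELECT concat_ws(NULL, 'a', 'b')").show()       // NULL
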
StructField("stringCol", StringType), StructField("binaryCol", BinaryType), StructField("booleanCol", BooleanType), - StructField("arrayCol", ArrayType(IntegerType, containsNull = true)), + StructField("arrayIntCol", ArrayType(IntegerType, containsNull = true)), + StructField("arrayStringCol", ArrayType(StringType, containsNull = true)), StructField("mapCol", MapType(StringType, StringType)), StructField("structCol", new StructType().add("a", StringType)) )) @@ -514,7 +557,8 @@ case class AllTypeTestData (spark: SparkSession) { dateSeq(i), timestampSeq(i), stringSeq(i), stringSeq(i).getBytes, booleanSeq(i), - arraySeq(i), + arrayIntSeq(i), + arrayStringSeq(i), mapSeq(i), structSeq(i)) }) @@ -533,7 +577,8 @@ case class AllTypeTestData (spark: SparkSession) { val stringCol = Column("stringCol") val binaryCol = Column("binaryCol") val booleanCol = Column("booleanCol") - val arrayCol = Column("arrayCol") + val arrayIntCol = Column("arrayIntCol") + val arrayStringCol = Column("arrayStringCol") val mapCol = Column("mapCol") val structCol = Column("structCol")