diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 9f92181b34df1..ae29cfe8119f6 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.commons.codec.binary.{Base64 => CommonsBase64}
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -2082,6 +2083,65 @@ case class UnBase64(child: Expression)
   }
 }
 
+object Decode {
+  def createExpr(params: Seq[Expression]): Expression = {
+    params.length match {
+      case 0 | 1 =>
+        throw new AnalysisException("Invalid number of arguments for function decode. " +
+          s"Expected: 2; Found: ${params.length}")
+      case 2 => StringDecode(params.head, params.last)
+      case _ =>
+        val input = params.head
+        val other = params.tail
+        val itr = other.iterator
+        var default: Expression = Literal.create(null, StringType)
+        val branches = ArrayBuffer.empty[(Expression, Expression)]
+        while (itr.hasNext) {
+          val search = itr.next
+          if (itr.hasNext) {
+            val condition = EqualTo(input, search)
+            branches += ((condition, itr.next))
+          } else {
+            default = search
+          }
+        }
+        CaseWhen(branches.seq, default)
+    }
+  }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    |_FUNC_(bin, charset) - Decodes the first argument using the second argument character set.
+    |
+    |_FUNC_(expr, search, result [, search, result ] ... [, default]) - Decode compares expr
+    | to each search value one by one. If expr is equal to a search value, returns the corresponding result.
+    | If no match is found, then it returns default. If default is omitted, returns null.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(encode('abc', 'utf-8'), 'utf-8');
+       abc
+      > SELECT _FUNC_(2, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic');
+       San Francisco
+      > SELECT _FUNC_(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic');
+       Non domestic
+      > SELECT _FUNC_(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle');
+       NULL
+  """,
+  since = "3.2.0")
+// scalastyle:on line.size.limit
+case class Decode(params: Seq[Expression], child: Expression) extends RuntimeReplaceable {
+
+  def this(params: Seq[Expression]) = {
+    this(params, Decode.createExpr(params))
+  }
+
+  override def flatArguments: Iterator[Any] = Iterator(params)
+  override def exprsReplaced: Seq[Expression] = params
+}
+
 /**
  * Decodes the first argument into a String using the provided character set
  * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
@@ -2097,7 +2157,7 @@ case class UnBase64(child: Expression)
   """,
   since = "1.5.0")
 // scalastyle:on line.size.limit
-case class Decode(bin: Expression, charset: Expression)
+case class StringDecode(bin: Expression, charset: Expression)
   extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def left: Expression = bin
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index adaabfe4d32bb..bbc9a47701a7d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -104,7 +104,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") {
     var strExpr: Expression = Literal("abc")
     for (_ <- 1 to 150) {
-      strExpr = Decode(Encode(strExpr, "utf-8"), "utf-8")
+      strExpr = StringDecode(Encode(strExpr, "utf-8"), "utf-8")
     }
 
     val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 730574a4b9846..b19ea6cbc0d4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -349,23 +349,23 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     // scalastyle:off
     // non ascii characters are not allowed in the code, so we disable the scalastyle here.
     checkEvaluation(
-      Decode(Encode(Literal("大千世界"), Literal("UTF-16LE")), Literal("UTF-16LE")), "大千世界")
+      StringDecode(Encode(Literal("大千世界"), Literal("UTF-16LE")), Literal("UTF-16LE")), "大千世界")
     checkEvaluation(
-      Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "大千世界", create_row("大千世界"))
+      StringDecode(Encode(a, Literal("utf-8")), Literal("utf-8")), "大千世界", create_row("大千世界"))
     checkEvaluation(
-      Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "", create_row(""))
+      StringDecode(Encode(a, Literal("utf-8")), Literal("utf-8")), "", create_row(""))
     // scalastyle:on
     checkEvaluation(Encode(a, Literal("utf-8")), null, create_row(null))
     checkEvaluation(Encode(Literal.create(null, StringType), Literal("utf-8")), null)
     checkEvaluation(Encode(a, Literal.create(null, StringType)), null, create_row(""))
 
-    checkEvaluation(Decode(b, Literal("utf-8")), null, create_row(null))
-    checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null)
-    checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null))
+    checkEvaluation(StringDecode(b, Literal("utf-8")), null, create_row(null))
+    checkEvaluation(StringDecode(Literal.create(null, BinaryType), Literal("utf-8")), null)
+    checkEvaluation(StringDecode(b, Literal.create(null, StringType)), null, create_row(null))
 
     // Test escaping of charset
     GenerateUnsafeProjection.generate(Encode(a, Literal("\"quote")) :: Nil)
-    GenerateUnsafeProjection.generate(Decode(b, Literal("\"quote")) :: Nil)
+    GenerateUnsafeProjection.generate(StringDecode(b, Literal("\"quote")) :: Nil)
   }
 
   test("initcap unit test") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 9861d21d3a430..7620003a82781 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2438,7 +2438,7 @@ object functions {
    * @since 1.5.0
    */
   def decode(value: Column, charset: String): Column = withExpr {
-    Decode(value.expr, lit(charset).expr)
+    StringDecode(value.expr, lit(charset).expr)
   }
 
   /**
diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
index f5ed2036dc8ac..80b4b8ca8cd54 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
@@ -53,3 +53,13 @@ SELECT trim(TRAILING 'xy' FROM 'TURNERyxXxy');
 -- Check lpad/rpad with invalid length parameter
 SELECT lpad('hi', 'invalid_length');
 SELECT rpad('hi', 'invalid_length');
+
+-- decode
+select decode();
+select decode(encode('abc', 'utf-8'));
+select decode(encode('abc', 'utf-8'), 'utf-8');
+select decode(1, 1, 'Southlake');
+select decode(2, 1, 'Southlake');
+select decode(2, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic');
+select decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic');
+select decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle');
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
index d5c0acb40bb1e..3164d462f8464 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 36
+-- Number of queries: 44
 
 
 -- !query
@@ -294,3 +294,69 @@ struct<>
 -- !query output
 java.lang.NumberFormatException
 invalid input syntax for type numeric: invalid_length
+
+
+-- !query
+select decode()
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+Invalid number of arguments for function decode. Expected: 2; Found: 0;; line 1 pos 7
+
+
+-- !query
+select decode(encode('abc', 'utf-8'))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+Invalid number of arguments for function decode. Expected: 2; Found: 1;; line 1 pos 7
+
+
+-- !query
+select decode(encode('abc', 'utf-8'), 'utf-8')
+-- !query schema
+struct
+-- !query output
+abc
+
+
+-- !query
+select decode(1, 1, 'Southlake')
+-- !query schema
+struct
+-- !query output
+Southlake
+
+
+-- !query
+select decode(2, 1, 'Southlake')
+-- !query schema
+struct
+-- !query output
+NULL
+
+
+-- !query
+select decode(2, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic')
+-- !query schema
+struct
+-- !query output
+San Francisco
+
+
+-- !query
+select decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic')
+-- !query schema
+struct
+-- !query output
+Non domestic
+
+
+-- !query
+select decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle')
+-- !query schema
+struct
+-- !query output
+NULL
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index 20c31b140b009..020a095d72e85 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 36
+-- Number of queries: 44
 
 
 -- !query
@@ -290,3 +290,69 @@ SELECT rpad('hi', 'invalid_length')
 struct
 -- !query output
 NULL
+
+
+-- !query
+select decode()
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+Invalid number of arguments for function decode. Expected: 2; Found: 0;; line 1 pos 7
+
+
+-- !query
+select decode(encode('abc', 'utf-8'))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+Invalid number of arguments for function decode. Expected: 2; Found: 1;; line 1 pos 7
+
+
+-- !query
+select decode(encode('abc', 'utf-8'), 'utf-8')
+-- !query schema
+struct
+-- !query output
+abc
+
+
+-- !query
+select decode(1, 1, 'Southlake')
+-- !query schema
+struct
+-- !query output
+Southlake
+
+
+-- !query
+select decode(2, 1, 'Southlake')
+-- !query schema
+struct
+-- !query output
+NULL
+
+
+-- !query
+select decode(2, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic')
+-- !query schema
+struct
+-- !query output
+San Francisco
+
+
+-- !query
+select decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic')
+-- !query schema
+struct
+-- !query output
+Non domestic
+
+
+-- !query
+select decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle')
+-- !query schema
+struct
+-- !query output
+NULL
\ No newline at end of file
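
For reference, a minimal sketch (not part of the patch) of how the two call shapes of decode() introduced above can be exercised end to end. It assumes a Spark build containing this change (3.2.0 or later) on the classpath; the object name DecodeDemo and the local-mode setup are made up for illustration.

import org.apache.spark.sql.SparkSession

object DecodeDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("decode-demo")
      .getOrCreate()

    // Exactly two arguments: the original charset form, now implemented by StringDecode.
    spark.sql("SELECT decode(encode('abc', 'utf-8'), 'utf-8')").show()
    // yields: abc

    // Three or more arguments: Decode.createExpr rewrites (expr, search, result, ..., default)
    // into a CaseWhen chain; here expr = 2 matches the second search value.
    spark.sql(
      "SELECT decode(2, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 'Non domestic')"
    ).show()
    // yields: San Francisco

    // No matching search value and no trailing default: the result is null.
    spark.sql("SELECT decode(6, 1, 'Southlake', 2, 'San Francisco')").show()
    // yields: null

    spark.stop()
  }
}

The argument count is what selects the behaviour: two arguments keep the existing charset decoding, three or more go through the CaseWhen rewrite, and zero or one fails analysis with the "Invalid number of arguments" error exercised in the golden files above.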