diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 259556826ad9..f2237540afde 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -2426,21 +2426,38 @@ case class Chr(child: Expression)
   """,
   since = "1.5.0",
   group = "string_funcs")
-case class Base64(child: Expression)
+case class Base64(child: Expression, chunkBase64: Boolean = SQLConf.get.chunkBase64StringEnabled)
   extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
+  // Keep an explicit 1-arg constructor: FunctionRegistry instantiates expressions
+  // reflectively, and Java reflection does not apply Scala default arguments.
+  def this(expr: Expression) = this(expr, SQLConf.get.chunkBase64StringEnabled)
+
+  // java.util.Base64.Encoder is not Serializable, so keep it transient for executors.
+  @transient private lazy val encoder: JBase64.Encoder = if (chunkBase64) {
+    JBase64.getMimeEncoder
+  } else {
+    JBase64.getMimeEncoder(-1, Array())
+  }
+
   override def dataType: DataType = StringType
   override def inputTypes: Seq[DataType] = Seq(BinaryType)
 
   protected override def nullSafeEval(bytes: Any): Any = {
-    UTF8String.fromBytes(JBase64.getMimeEncoder.encode(bytes.asInstanceOf[Array[Byte]]))
+    UTF8String.fromBytes(encoder.encode(bytes.asInstanceOf[Array[Byte]]))
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, (child) => {
-      s"""${ev.value} = UTF8String.fromBytes(
-            ${classOf[JBase64].getName}.getMimeEncoder().encode($child));
-       """})
+      if (chunkBase64) {
+        s"""${ev.value} = UTF8String.fromBytes(
+              ${classOf[JBase64].getName}.getMimeEncoder().encode($child));
+           """
+      } else {
+        s"""${ev.value} = UTF8String.fromBytes(
+              ${classOf[JBase64].getName}.getMimeEncoder(-1, new byte[0]).encode($child));
+           """
+      }
+    })
   }
 
   override protected def withNewChildInternal(newChild: Expression): Base64 =
     copy(child = newChild)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index bc4734775c77..f1ece91f25e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3318,6 +3318,13 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val CHUNK_BASE_64_STRING_ENABLED = buildConf("spark.sql.chunkBase64String.enabled")
+    .doc("When true, base64 strings generated by the base64 function are chunked into lines of " +
+      "at most 76 characters. When false, the base64 strings are not chunked.")
+    .version("4.0.0")
+    .booleanConf
+    .createWithDefault(true)
+
   val ENABLE_DEFAULT_COLUMNS =
     buildConf("spark.sql.defaultColumn.enabled")
       .internal()
@@ -5398,6 +5405,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def ansiRelationPrecedence: Boolean = ansiEnabled && getConf(ANSI_RELATION_PRECEDENCE)
 
+  def chunkBase64StringEnabled: Boolean = getConf(CHUNK_BASE_64_STRING_ENABLED)
+
   def timestampType: AtomicType = getConf(TIMESTAMP_TYPE) match {
     case "TIMESTAMP_LTZ" =>
       // For historical reason, the TimestampType maps to TIMESTAMP WITH LOCAL TIME ZONE
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 98f33e209994..a3b6b4ac865b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -506,6 +506,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     GenerateUnsafeProjection.generate(StringDecode(b, Literal("\"quote")) :: Nil)
   }
 
+  test("SPARK-47307: base64 encoding without chunking") {
+    val longString = "a" * 58
+    val encoded = "YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYQ=="
+
+    checkEvaluation(Base64(Literal(longString.getBytes), false),
+      encoded, create_row("abcdefgh"))
+  }
+
   test("initcap unit test") {
     checkEvaluation(InitCap(Literal.create(null, StringType)), null)
     checkEvaluation(InitCap(Literal("a b")), "A B")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/Base64Benchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/Base64Benchmark.scala
index 3ad6baea84f2..a021477e88a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/Base64Benchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/Base64Benchmark.scala
@@ -56,7 +56,7 @@ object Base64Benchmark extends SqlBasedBenchmark {
     Seq(1, 3, 5, 7).map { len =>
       val benchmark = new Benchmark(s"encode for $len", N, output = output)
       benchmark.addCase("java", 3) { _ =>
-        doEncode(len, x => java.util.Base64.getMimeEncoder().encode(x))
+        doEncode(len, x => java.util.Base64.getMimeEncoder.encode(x))
       }
       benchmark.addCase(s"apache", 3) { _ =>
         doEncode(len, org.apache.commons.codec.binary.Base64.encodeBase64)