diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_base64.explain b/connect/common/src/test/resources/query-tests/explain-results/function_base64.explain index f80f3522190d..d3a250919ea5 100644 --- a/connect/common/src/test/resources/query-tests/explain-results/function_base64.explain +++ b/connect/common/src/test/resources/query-tests/explain-results/function_base64.explain @@ -1,2 +1,2 @@ -Project [base64(cast(g#0 as binary)) AS base64(CAST(g AS BINARY))#0] +Project [static_invoke(Base64.encode(cast(g#0 as binary), false)) AS base64(CAST(g AS BINARY))#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index f25f58731c8c..b188b9c2630f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -2682,24 +2682,40 @@ case class Chr(child: Expression) """, since = "1.5.0", group = "string_funcs") -case class Base64(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { +case class Base64(child: Expression, chunkBase64: Boolean) + extends UnaryExpression with RuntimeReplaceable with ImplicitCastInputTypes { + + def this(expr: Expression) = this(expr, SQLConf.get.chunkBase64StringEnabled) override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(BinaryType) - protected override def nullSafeEval(bytes: Any): Any = { - UTF8String.fromBytes(JBase64.getMimeEncoder.encode(bytes.asInstanceOf[Array[Byte]])) - } + override def replacement: Expression = StaticInvoke( + classOf[Base64], + dataType, + "encode", + Seq(child, Literal(chunkBase64, BooleanType)), + Seq(BinaryType, BooleanType)) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (child) => { - s"""${ev.value} = UTF8String.fromBytes( - ${classOf[JBase64].getName}.getMimeEncoder().encode($child)); - """}) - } + override def toString: String = s"$prettyName($child)" - override protected def withNewChildInternal(newChild: Expression): Base64 = copy(child = newChild) + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +object Base64 { + def apply(expr: Expression): Base64 = new Base64(expr) + + private lazy val nonChunkEncoder = JBase64.getMimeEncoder(-1, Array()) + + def encode(input: Array[Byte], chunkBase64: Boolean): UTF8String = { + val encoder = if (chunkBase64) { + JBase64.getMimeEncoder + } else { + nonChunkEncoder + } + UTF8String.fromBytes(encoder.encode(input)) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6ca831f99304..65beb21d59d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3525,6 +3525,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val CHUNK_BASE64_STRING_ENABLED = buildConf("spark.sql.legacy.chunkBase64String.enabled") + .internal() + .doc("Whether to truncate string generated by the `Base64` function. When true, base64" + + " strings generated by the base64 function are chunked into lines of at most 76" + + " characters. When false, the base64 strings are not chunked.") + .version("3.5.2") + .booleanConf + .createWithDefault(false) + val ENABLE_DEFAULT_COLUMNS = buildConf("spark.sql.defaultColumn.enabled") .internal() @@ -5856,6 +5865,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def ansiRelationPrecedence: Boolean = ansiEnabled && getConf(ANSI_RELATION_PRECEDENCE) + def chunkBase64StringEnabled: Boolean = getConf(CHUNK_BASE64_STRING_ENABLED) + def timestampType: AtomicType = getConf(TIMESTAMP_TYPE) match { case "TIMESTAMP_LTZ" => // For historical reason, the TimestampType maps to TIMESTAMP WITH LOCAL TIME ZONE diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index ebd724543481..2ad8652f2b31 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -509,6 +509,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { GenerateUnsafeProjection.generate(StringDecode(b, Literal("\"quote")).replacement :: Nil) } + test("SPARK-47307: base64 encoding without chunking") { + val longString = "a" * 58 + val encoded = "YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYQ==" + withSQLConf(SQLConf.CHUNK_BASE64_STRING_ENABLED.key -> "false") { + checkEvaluation(Base64(Literal(longString.getBytes)), encoded) + } + val chunkEncoded = + s"YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFh\r\nYQ==" + withSQLConf(SQLConf.CHUNK_BASE64_STRING_ENABLED.key -> "true") { + checkEvaluation(Base64(Literal(longString.getBytes)), chunkEncoded) + } + } + test("initcap unit test") { checkEvaluation(InitCap(Literal.create(null, StringType)), null) checkEvaluation(InitCap(Literal("a b")), "A B")