Skip to content

Commit 449e2c9

Browse files
committed
Fix
1 parent 09fd22e commit 449e2c9

File tree

3 files changed

+26
-80
lines changed

3 files changed

+26
-80
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,13 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.codegen;
1919

20-
import java.nio.charset.StandardCharsets;
21-
2220
import org.apache.spark.unsafe.Platform;
2321
import org.apache.spark.unsafe.array.ByteArrayMethods;
2422
import org.apache.spark.unsafe.types.UTF8String;
2523

2624
/**
27-
* A helper class to write `UTF8String`, `String`, and `byte[]` data into an internal byte buffer
28-
* and get written data as `UTF8String`.
25+
* A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated
26+
* {@link UTF8String} at the end.
2927
*/
3028
public class UTF8StringBuilder {
3129

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
619619
"""
620620
}
621621

622-
private[this] def codegenWriteArrayElemCode(et: DataType, ctx: CodegenContext): String = {
622+
private def writeArrayToStringBuilder(
623+
et: DataType,
624+
arTerm: String,
625+
bufferTerm: String,
626+
ctx: CodegenContext): String = {
623627
val elementToStringCode = castToStringCode(et, ctx)
624628
val funcName = ctx.freshName("elementToString")
625629
val elementToStringFunc = ctx.addNewFunction(funcName,
@@ -632,29 +636,22 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
632636
""".stripMargin)
633637

634638
val loopIndex = ctx.freshName("loopIndex")
635-
val writeArrayToBuffer = ctx.freshName("writeArrayToBuffer")
636-
val arTerm = ctx.freshName("arTerm")
637-
val bufferClass = classOf[UTF8StringBuilder].getName
638-
val bufferTerm = ctx.freshName("bufferTerm")
639-
ctx.addNewFunction(writeArrayToBuffer,
640-
s"""
641-
|private void $writeArrayToBuffer(ArrayData $arTerm, $bufferClass $bufferTerm) {
642-
| $bufferTerm.append("[");
643-
| if ($arTerm.numElements() > 0) {
644-
| if (!$arTerm.isNullAt(0)) {
645-
| $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, "0")}));
646-
| }
647-
| for (int $loopIndex = 1; $loopIndex < $arTerm.numElements(); $loopIndex++) {
648-
| $bufferTerm.append(",");
649-
| if (!$arTerm.isNullAt($loopIndex)) {
650-
| $bufferTerm.append(" ");
651-
| $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, loopIndex)}));
652-
| }
653-
| }
654-
| }
655-
| $bufferTerm.append("]");
656-
|}
657-
""".stripMargin)
639+
s"""
640+
|$bufferTerm.append("[");
641+
|if ($arTerm.numElements() > 0) {
642+
| if (!$arTerm.isNullAt(0)) {
643+
| $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, "0")}));
644+
| }
645+
| for (int $loopIndex = 1; $loopIndex < $arTerm.numElements(); $loopIndex++) {
646+
| $bufferTerm.append(",");
647+
| if (!$arTerm.isNullAt($loopIndex)) {
648+
| $bufferTerm.append(" ");
649+
| $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, loopIndex)}));
650+
| }
651+
| }
652+
|}
653+
|$bufferTerm.append("]");
654+
""".stripMargin
658655
}
659656

660657
private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = {
@@ -672,10 +669,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
672669
(c, evPrim, evNull) => {
673670
val bufferTerm = ctx.freshName("bufferTerm")
674671
val bufferClass = classOf[UTF8StringBuilder].getName
675-
val writeArrayElemCode = codegenWriteArrayElemCode(et, ctx)
672+
val writeArrayElemCode = writeArrayToStringBuilder(et, c, bufferTerm, ctx)
676673
s"""
677674
|$bufferClass $bufferTerm = new $bufferClass();
678-
|$writeArrayElemCode($c, $bufferTerm);
675+
|$writeArrayElemCode;
679676
|$evPrim = $bufferTerm.build();
680677
""".stripMargin
681678
}

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 1 addition & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql
2020
import java.io.File
2121
import java.math.MathContext
2222
import java.net.{MalformedURLException, URL}
23-
import java.sql.{Date, Timestamp}
23+
import java.sql.Timestamp
2424
import java.util.concurrent.atomic.AtomicBoolean
2525

2626
import org.apache.spark.{AccumulatorSuite, SparkException}
@@ -2773,53 +2773,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
27732773
}
27742774
}
27752775
}
2776-
2777-
test("SPARK-22825 Cast array to string") {
2778-
Seq("true", "false").foreach { codegen =>
2779-
withSQLConf("spark.sql.codegen.wholeStage" -> codegen) {
2780-
withTable("t") {
2781-
Seq(Seq(0, 1, 2, 3, 4)).toDF("a").write.saveAsTable("t")
2782-
val df = sql("SELECT CAST(a AS STRING) FROM t")
2783-
checkAnswer(df, Row("[0, 1, 2, 3, 4]"))
2784-
}
2785-
withTable("t") {
2786-
Seq(Seq("ab", "cde", "f")).toDF("a").write.saveAsTable("t")
2787-
val df = sql("SELECT CAST(a AS STRING) FROM t")
2788-
checkAnswer(df, Row("[ab, cde, f]"))
2789-
}
2790-
withTable("t") {
2791-
Seq(Seq("ab", null, "c")).toDF("a").write.saveAsTable("t")
2792-
val df = sql("SELECT CAST(a AS STRING) FROM t")
2793-
checkAnswer(df, Row("[ab,, c]"))
2794-
}
2795-
withTable("t") {
2796-
Seq(Seq("ab".getBytes, "cde".getBytes, "f".getBytes)).toDF("a").write.saveAsTable("t")
2797-
val df = sql("SELECT CAST(a AS STRING) FROM t")
2798-
checkAnswer(df, Row("[ab, cde, f]"))
2799-
}
2800-
withTable("t") {
2801-
Seq(Seq("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf))
2802-
.toDF("a").write.saveAsTable("t")
2803-
val df = sql("SELECT CAST(a AS STRING) FROM t")
2804-
checkAnswer(df, Row("[2014-12-03, 2014-12-04, 2014-12-06]"))
2805-
}
2806-
withTable("t") {
2807-
Seq(Seq("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf))
2808-
.toDF("a").write.saveAsTable("t")
2809-
val df = sql("SELECT CAST(a AS STRING) FROM t")
2810-
checkAnswer(df, Row("[2014-12-03 13:01:00, 2014-12-04 15:05:00]"))
2811-
}
2812-
withTable("t") {
2813-
Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").write.saveAsTable("t")
2814-
val df = sql("SELECT CAST(a AS STRING) FROM t")
2815-
checkAnswer(df, Row("[[1, 2], [3], [4, 5, 6]]"))
2816-
}
2817-
withTable("t") {
2818-
Seq(Seq(Seq(Seq("a"), Seq("b", "c")), Seq(Seq("d")))).toDF("a").write.saveAsTable("t")
2819-
val df = sql("SELECT CAST(a AS STRING) FROM t")
2820-
checkAnswer(df, Row("[[[a], [b, c]], [[d]]]"))
2821-
}
2822-
}
2823-
}
2824-
}
28252776
}

0 commit comments

Comments
 (0)