Skip to content

Commit 09fd22e

Browse files
committed
Fix
1 parent b0b3cd6 commit 09fd22e

File tree

4 files changed

+61
-66
lines changed

4 files changed

+61
-66
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,30 +60,21 @@ private void grow(int neededSize) {
6060
}
6161
}
6262

63+
private int totalSize() {
64+
return cursor - Platform.BYTE_ARRAY_OFFSET;
65+
}
66+
6367
public void append(UTF8String value) {
6468
grow(value.numBytes());
6569
value.writeToMemory(buffer, cursor);
6670
cursor += value.numBytes();
6771
}
6872

6973
public void append(String value) {
70-
append(value.getBytes(StandardCharsets.UTF_8));
71-
}
72-
73-
public void append(byte[] value) {
74-
grow(value.length);
75-
Platform.copyMemory(value, Platform.BYTE_ARRAY_OFFSET, buffer, cursor, value.length);
76-
cursor += value.length;
77-
}
78-
79-
public UTF8String toUTF8String() {
80-
final int len = totalSize();
81-
final byte[] bytes = new byte[len];
82-
Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, bytes, Platform.BYTE_ARRAY_OFFSET, len);
83-
return UTF8String.fromBytes(bytes);
74+
append(UTF8String.fromString(value));
8475
}
8576

86-
public int totalSize() {
87-
return cursor - Platform.BYTE_ARRAY_OFFSET;
77+
public UTF8String build() {
78+
return UTF8String.fromBytes(buffer, 0, totalSize());
8879
}
8980
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -206,22 +206,27 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
206206
case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
207207
case TimestampType => buildCast[Long](_,
208208
t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
209-
case ar: ArrayType =>
209+
case ArrayType(et, _) =>
210210
buildCast[ArrayData](_, array => {
211-
val res = new UTF8StringBuilder
212-
res.append("[")
211+
val builder = new UTF8StringBuilder
212+
builder.append("[")
213213
if (array.numElements > 0) {
214-
val toUTF8String = castToString(ar.elementType)
215-
res.append(toUTF8String(array.get(0, ar.elementType)).asInstanceOf[UTF8String])
214+
val toUTF8String = castToString(et)
215+
if (!array.isNullAt(0)) {
216+
builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String])
217+
}
216218
var i = 1
217219
while (i < array.numElements) {
218-
res.append(", ")
219-
res.append(toUTF8String(array.get(i, ar.elementType)).asInstanceOf[UTF8String])
220+
builder.append(",")
221+
if (!array.isNullAt(i)) {
222+
builder.append(" ")
223+
builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String])
224+
}
220225
i += 1
221226
}
222227
}
223-
res.append("]")
224-
res.toUTF8String
228+
builder.append("]")
229+
builder.build()
225230
})
226231
case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
227232
}
@@ -614,45 +619,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
614619
"""
615620
}
616621

617-
private[this] def writeElemToBufferCode(
618-
dataType: DataType,
619-
buffer: String,
620-
elemTerm: String,
621-
ctx: CodegenContext): String = dataType match {
622-
case BinaryType | StringType => s"$buffer.append($elemTerm)"
623-
case DateType => s"""$buffer.append(
624-
org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($elemTerm))"""
625-
case TimestampType => s"""$buffer.append(
626-
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($elemTerm))"""
627-
case ar: ArrayType => s"${codegenWriteArrayToBuffer(ar, ctx)}($elemTerm, $buffer)"
628-
case _ => s"$buffer.append(String.valueOf($elemTerm))"
629-
}
622+
private[this] def codegenWriteArrayElemCode(et: DataType, ctx: CodegenContext): String = {
623+
val elementToStringCode = castToStringCode(et, ctx)
624+
val funcName = ctx.freshName("elementToString")
625+
val elementToStringFunc = ctx.addNewFunction(funcName,
626+
s"""
627+
|private UTF8String $funcName(${ctx.javaType(et)} element) {
628+
| UTF8String elementStr = null;
629+
| ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
630+
| return elementStr;
631+
|}
632+
""".stripMargin)
630633

631-
private[this] def codegenWriteArrayToBuffer(ar: ArrayType, ctx: CodegenContext): String = {
632634
val loopIndex = ctx.freshName("loopIndex")
633635
val writeArrayToBuffer = ctx.freshName("writeArrayToBuffer")
634636
val arTerm = ctx.freshName("arTerm")
635637
val bufferClass = classOf[UTF8StringBuilder].getName
636638
val bufferTerm = ctx.freshName("bufferTerm")
637-
def writeElemCode(elemTerm: String) = {
638-
writeElemToBufferCode(ar.elementType, bufferTerm, elemTerm, ctx)
639-
}
640-
def writeToBufferCode(i: String) = {
641-
val elemTerm = ctx.freshName("elemTerm")
642-
s"""
643-
|${ctx.javaType(ar.elementType)} $elemTerm = ${ctx.getValue(arTerm, ar.elementType, i)};
644-
|${writeElemCode(elemTerm)};
645-
""".stripMargin
646-
}
647639
ctx.addNewFunction(writeArrayToBuffer,
648640
s"""
649641
|private void $writeArrayToBuffer(ArrayData $arTerm, $bufferClass $bufferTerm) {
650642
| $bufferTerm.append("[");
651643
| if ($arTerm.numElements() > 0) {
652-
| ${writeToBufferCode("0")}
644+
| if (!$arTerm.isNullAt(0)) {
645+
| $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, "0")}));
646+
| }
653647
| for (int $loopIndex = 1; $loopIndex < $arTerm.numElements(); $loopIndex++) {
654-
| $bufferTerm.append(", ");
655-
| ${writeToBufferCode(loopIndex)}
648+
| $bufferTerm.append(",");
649+
| if (!$arTerm.isNullAt($loopIndex)) {
650+
| $bufferTerm.append(" ");
651+
| $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, loopIndex)}));
652+
| }
656653
| }
657654
| }
658655
| $bufferTerm.append("]");
@@ -671,15 +668,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
671668
val tz = ctx.addReferenceObj("timeZone", timeZone)
672669
(c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
673670
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
674-
case ar: ArrayType =>
671+
case ArrayType(et, _) =>
675672
(c, evPrim, evNull) => {
676673
val bufferTerm = ctx.freshName("bufferTerm")
677674
val bufferClass = classOf[UTF8StringBuilder].getName
678-
val writeArrayToBuffer = codegenWriteArrayToBuffer(ar, ctx)
675+
val writeArrayElemCode = codegenWriteArrayElemCode(et, ctx)
679676
s"""
680677
|$bufferClass $bufferTerm = new $bufferClass();
681-
|$writeArrayToBuffer($c, $bufferTerm);
682-
|$evPrim = $bufferTerm.toUTF8String();
678+
|$writeArrayElemCode($c, $bufferTerm);
679+
|$evPrim = $bufferTerm.build();
683680
""".stripMargin
684681
}
685682
case _ =>

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -859,21 +859,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
859859
checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
860860
val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType)
861861
checkEvaluation(ret2, "[ab, cde, f]")
862-
val ret3 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
863-
checkEvaluation(ret3, "[ab, cde, f]")
864-
val ret4 = cast(
862+
val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType)
863+
checkEvaluation(ret3, "[ab,, c]")
864+
val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
865+
checkEvaluation(ret4, "[ab, cde, f]")
866+
val ret5 = cast(
865867
Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)),
866868
StringType)
867-
checkEvaluation(ret4, "[2014-12-03, 2014-12-04, 2014-12-06]")
868-
val ret5 = cast(
869+
checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]")
870+
val ret6 = cast(
869871
Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)),
870872
StringType)
871-
checkEvaluation(ret5, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
872-
val ret6 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
873-
checkEvaluation(ret6, "[[1, 2, 3], [4, 5]]")
874-
val ret7 = cast(
873+
checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
874+
val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
875+
checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]")
876+
val ret8 = cast(
875877
Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))),
876878
StringType)
877-
checkEvaluation(ret7, "[[[a], [b, c]], [[d]]]")
879+
checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]")
878880
}
879881
}

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2787,6 +2787,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
27872787
val df = sql("SELECT CAST(a AS STRING) FROM t")
27882788
checkAnswer(df, Row("[ab, cde, f]"))
27892789
}
2790+
withTable("t") {
2791+
Seq(Seq("ab", null, "c")).toDF("a").write.saveAsTable("t")
2792+
val df = sql("SELECT CAST(a AS STRING) FROM t")
2793+
checkAnswer(df, Row("[ab,, c]"))
2794+
}
27902795
withTable("t") {
27912796
Seq(Seq("ab".getBytes, "cde".getBytes, "f".getBytes)).toDF("a").write.saveAsTable("t")
27922797
val df = sql("SELECT CAST(a AS STRING) FROM t")

0 commit comments

Comments
 (0)