Skip to content

Commit b5b5e35

Browse files
committed
Add UTF8StringBuilder
1 parent 99c3ed0 commit b5b5e35

File tree

2 files changed

+91
-11
lines changed

2 files changed

+91
-11
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.codegen;
19+
20+
import java.nio.charset.StandardCharsets;
21+
22+
import org.apache.spark.unsafe.types.UTF8String;
23+
24+
/**
25+
* A helper class to write `UTF8String`, `String`, and `byte[]` data into an internal buffer
26+
* and get a final concatenated string.
27+
*/
28+
public class UTF8StringBuilder {
29+
30+
private StringBuilder buffer;
31+
32+
public UTF8StringBuilder() {
33+
this.buffer = new StringBuilder();
34+
}
35+
36+
public void append(UTF8String value) {
37+
buffer.append(value);
38+
}
39+
40+
public void append(String value) {
41+
buffer.append(value);
42+
}
43+
44+
public void append(byte[] value) {
45+
buffer.append(new String(value, StandardCharsets.UTF_8));
46+
}
47+
48+
@Override
49+
public String toString() {
50+
return buffer.toString();
51+
}
52+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -199,30 +199,59 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
199199

200200
// [[func]] assumes the input is no longer null because eval already does the null check.
201201
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
202+
@inline private[this] def buildWriter[T](
203+
a: Any, buffer: UTF8StringBuilder, writer: (T, UTF8StringBuilder) => Unit): Unit = {
204+
writer(a.asInstanceOf[T], buffer)
205+
}
206+
207+
private[this] def buildElemWriter(
208+
from: DataType): (Any, UTF8StringBuilder) => Unit = from match {
209+
case BinaryType => buildWriter[Array[Byte]](_, _, (b, buf) => buf.append(b))
210+
case StringType => buildWriter[UTF8String](_, _, (b, buf) => buf.append(b))
211+
case DateType => buildWriter[Int](_, _,
212+
(d, buf) => buf.append(DateTimeUtils.dateToString(d)))
213+
case TimestampType => buildWriter[Long](_, _,
214+
(t, buf) => buf.append(DateTimeUtils.timestampToString(t)))
215+
case ar: ArrayType =>
216+
buildWriter[ArrayData](_, _, (array, buf) => {
217+
buf.append("[")
218+
if (array.numElements > 0) {
219+
val writeElemToBuffer = buildElemWriter(ar.elementType)
220+
writeElemToBuffer(array.get(0, ar.elementType), buf)
221+
var i = 1
222+
while (i < array.numElements) {
223+
buf.append(", ")
224+
writeElemToBuffer(array.get(i, ar.elementType), buf)
225+
i += 1
226+
}
227+
}
228+
buf.append("]")
229+
})
230+
case _ => buildWriter[Any](_, _, (o, buf) => buf.append(String.valueOf(o)))
231+
}
202232

203233
// UDFToString
204234
private[this] def castToString(from: DataType): Any => Any = from match {
205235
case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes)
206-
case StringType => buildCast[UTF8String](_, identity)
207236
case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
208237
case TimestampType => buildCast[Long](_,
209238
t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
210239
case ar: ArrayType =>
211240
buildCast[ArrayData](_, array => {
212-
val res = new StringBuilder
241+
val res = new UTF8StringBuilder
213242
res.append("[")
214243
if (array.numElements > 0) {
215-
val toStringFunc = castToString(ar.elementType)
216-
res.append(toStringFunc(array.get(0, ar.elementType)))
244+
val writeElemToBuffer = buildElemWriter(ar.elementType)
245+
writeElemToBuffer(array.get(0, ar.elementType), res)
217246
var i = 1
218247
while (i < array.numElements) {
219248
res.append(", ")
220-
res.append(toStringFunc(array.get(i, ar.elementType)))
249+
writeElemToBuffer(array.get(i, ar.elementType), res)
221250
i += 1
222251
}
223252
}
224253
res.append("]")
225-
UTF8String.fromString(res.toString())
254+
UTF8String.fromString(res.toString)
226255
})
227256
case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
228257
}
@@ -620,21 +649,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
620649
buffer: String,
621650
elemTerm: String,
622651
ctx: CodegenContext): String = dataType match {
623-
case BinaryType => s"$buffer.append(new String($elemTerm))"
624-
case StringType => s"$buffer.append(new String($elemTerm.getBytes()))"
652+
case BinaryType | StringType => s"$buffer.append($elemTerm)"
625653
case DateType => s"""$buffer.append(
626654
org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($elemTerm))"""
627655
case TimestampType => s"""$buffer.append(
628656
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($elemTerm))"""
629657
case ar: ArrayType => s"${codegenWriteArrayToBuffer(ar, ctx)}($elemTerm, $buffer)"
630-
case _ => s"$buffer.append($elemTerm)"
658+
case _ => s"$buffer.append(String.valueOf($elemTerm))"
631659
}
632660

633661
private[this] def codegenWriteArrayToBuffer(ar: ArrayType, ctx: CodegenContext): String = {
634662
val loopIndex = ctx.freshName("loopIndex")
635663
val writeArrayToBuffer = ctx.freshName("writeArrayToBuffer")
636664
val arTerm = ctx.freshName("arTerm")
637-
val bufferClass = "java.lang.StringBuilder"
665+
val bufferClass = classOf[UTF8StringBuilder].getName
638666
val bufferTerm = ctx.freshName("bufferTerm")
639667
def writeElemCode(elemTerm: String) = {
640668
writeElemToBufferCode(ar.elementType, bufferTerm, elemTerm, ctx)
@@ -676,7 +704,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
676704
case ar: ArrayType =>
677705
(c, evPrim, evNull) => {
678706
val bufferTerm = ctx.freshName("bufferTerm")
679-
val bufferClass = "java.lang.StringBuilder"
707+
val bufferClass = classOf[UTF8StringBuilder].getName
680708
val writeArrayToBuffer = codegenWriteArrayToBuffer(ar, ctx)
681709
s"""
682710
|$bufferClass $bufferTerm = new $bufferClass();

0 commit comments

Comments
 (0)