Skip to content

Commit 086caba

Browse files
committed
[SPARK-9154][SQL] codegen string format
1 parent 5bdf16d commit 086caba

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,30 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
501501
}
502502
}
503503

504+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
505+
val pattern = children.head.gen(ctx)
506+
val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
507+
val argListCode = argListGen.map(_._2.code + "\n")
508+
val argListString = argListGen.foldLeft("")((s, x) => s + s", ${x._2.primitive}" + (if (!ctx.isPrimitiveType(x._1)) ".toString()" else ""))
509+
val form = ctx.freshName("formatter")
510+
val formatter = classOf[java.util.Formatter].getName
511+
val sb = ctx.freshName("sb")
512+
val stringBuffer = classOf[StringBuffer].getName
513+
514+
s"""
515+
${pattern.code}
516+
boolean ${ev.isNull} = ${pattern.isNull};
517+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
518+
if (!${ev.isNull}) {
519+
${argListCode.mkString}
520+
$stringBuffer $sb = new $stringBuffer();
521+
$formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
522+
$form.format(${pattern.primitive}.toString() $argListString);
523+
${ev.primitive} = UTF8String.fromString($sb.toString());
524+
}
525+
"""
526+
}
527+
504528
override def prettyName: String = "printf"
505529
}
506530

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
361361
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
362362
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
363363

364+
println(StringFormat(f, d1, s1).eval(row1))
364365
checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
365366
checkEvaluation(StringFormat(f, d1, s1), null, row2)
366367
}

0 commit comments

Comments
 (0)