Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns the input formatted according to printf-style format strings.
*/
case class StringFormat(children: Expression*) extends Expression with CodegenFallback {
case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {

require(children.nonEmpty, "printf() should take at least 1 argument")

Expand All @@ -486,6 +486,10 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
/** The printf-style pattern expression; it is always the first child. */
private def format: Expression = children.head

/** The expressions substituted into the pattern: every child after the first. */
private def args: Seq[Expression] = children.drop(1)

/**
 * Expected input types, one per child: the leading format pattern must be a
 * string, while each of the remaining arguments may be of any data type.
 */
override def inputTypes: Seq[AbstractDataType] =
  StringType :: List.fill(children.size - 1)(AnyDataType)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@marmbrus Is this what you proposed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works. I think `StringType :: List.fill(children.size - 1)(AnyDataType)` might be a little clearer.



override def eval(input: InternalRow): Any = {
val pattern = format.eval(input)
if (pattern == null) {
Expand All @@ -501,6 +505,42 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
}
}

/**
 * Produces the Java source snippet that evaluates this printf-style expression.
 *
 * The emitted code first evaluates the format pattern (the first child). If the
 * pattern is null the result is null and the argument code is not run; otherwise
 * the code of the remaining children is evaluated and the values are formatted
 * with a `java.util.Formatter` using `Locale.US`.
 *
 * @param ctx code generation context, used here for Java type names, default
 *            values and fresh variable names
 * @param ev  supplies the names of the result variables (`isNull`, `primitive`)
 *            that the emitted code declares and assigns
 * @return the generated Java source as a string
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
// Generated code for the format pattern (first child).
val pattern = children.head.gen(ctx)

// Pair every remaining child's data type with its generated code.
val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
// The evaluation snippets of the arguments, one per line.
val argListCode = argListGen.map(_._2.code + "\n")

// Builds the trailing part of the Formatter.format(...) call. Every argument is
// prefixed with "," so the result can be appended directly after the pattern
// argument, and every argument evaluates to null when its isNull flag is set.
val argListString = argListGen.foldLeft("")((s, v) => {
val nullSafeString =
if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
// Java primitives get boxed in order to allow null values.
s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
s"new ${ctx.boxedType(v._1)}(${v._2.primitive})"
} else {
// The Java type is already a reference type, so null can be passed as is.
s"(${v._2.isNull}) ? null : ${v._2.primitive}"
}
s + "," + nullSafeString
})

// Fresh variable names and fully qualified class names used by the emitted code.
val form = ctx.freshName("formatter")
val formatter = classOf[java.util.Formatter].getName
val sb = ctx.freshName("sb")
val stringBuffer = classOf[StringBuffer].getName
// Emitted code: the result is null iff the pattern is null; otherwise format the
// arguments into the buffer and convert the buffer contents to a UTF8String.
s"""
${pattern.code}
boolean ${ev.isNull} = ${pattern.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${argListCode.mkString}
$stringBuffer $sb = new $stringBuffer();
$formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
$form.format(${pattern.primitive}.toString() $argListString);
${ev.primitive} = UTF8String.fromString($sb.toString());
}
"""
}

// Display name of this expression. NOTE(review): presumably consumed by the
// Expression base class when rendering the expression; "printf" matches the
// function name used in SQL expressions such as printf(a, b, c) - confirm.
override def prettyName: String = "printf"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("FORMAT") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

existing: would you mind rewriting these to avoid the use of row and just use literals? using a row makes the test cases harder to follow since you have to look in multiple places to understand what is going on.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for cleaning this up!

val f = 'f.string.at(0)
val d1 = 'd.int.at(1)
val s1 = 's.int.at(2)

val row1 = create_row("aa%d%s", 12, "cc")
val row2 = create_row(null, 12, "cc")
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")

checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
checkEvaluation(StringFormat(f, d1, s1), null, row2)
checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
checkEvaluation(
StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
checkEvaluation(
StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
}

test("INSTR") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ class StringFunctionsSuite extends QueryTest {
checkAnswer(
df.selectExpr("printf(a, b, c)"),
Row("aa123cc"))

val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")

checkAnswer(
df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
Row("aa123cc", "aa123cc"))

checkAnswer(
df2.selectExpr("printf(a, b, c)"),
Row("aa123cc"))
}

test("string instr function") {
Expand Down