Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns the input formatted according to printf-style format strings
*/
case class StringFormat(children: Expression*) extends Expression with CodegenFallback {
case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {

require(children.nonEmpty, "printf() should take at least 1 argument")

Expand All @@ -536,6 +536,10 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
// The first child is the expression that yields the printf-style format string.
private def format: Expression = children(0)
// Every child after the first is an argument handed to the formatter.
private def args: Seq[Expression] = children.tail

/**
 * Declares the expected input types: StringType for the first child (the format string)
 * and AnyDataType for each of the remaining children.
 */
override def inputTypes: Seq[AbstractDataType] = {
  val argCount = children.size - 1
  val argTypes: List[AbstractDataType] = List.fill(argCount)(AnyDataType)
  StringType :: argTypes
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated this as well.



override def eval(input: InternalRow): Any = {
val pattern = format.eval(input)
if (pattern == null) {
Expand All @@ -551,6 +555,42 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
}
}

/**
 * Generates the Java source that evaluates this expression.
 *
 * The emitted code evaluates the format string first; when it is null the result is null and
 * the arguments are not evaluated. Otherwise every argument is evaluated and handed to a
 * `java.util.Formatter` (US locale) that writes into a `StringBuffer`, whose content becomes
 * the resulting UTF8String.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  // Generate the format string first and the arguments afterwards, in child order.
  val patternGen = children.head.gen(ctx)
  val typedArgGens = children.tail.map(arg => (arg.dataType, arg.gen(ctx)))

  // Evaluation code of all arguments, one snippet per line.
  val argEvalCode = typedArgGens.map { case (_, argGen) => argGen.code + "\n" }.mkString

  // Varargs passed to Formatter.format; each one is prefixed with "," so the list can be
  // appended directly after the pattern argument (and is empty when there are no arguments).
  val formatArgs = typedArgGens.map { case (argType, argGen) =>
    val boxed = ctx.boxedType(argType)
    val nullSafeValue =
      if (boxed == ctx.javaType(argType)) {
        s"(${argGen.isNull}) ? null : ${argGen.primitive}"
      } else {
        // Java primitives get boxed in order to allow null values.
        s"(${argGen.isNull}) ? ($boxed) null : " +
          s"new $boxed(${argGen.primitive})"
      }
    "," + nullSafeValue
  }.mkString

  val formatterVar = ctx.freshName("formatter")
  val formatterClass = classOf[java.util.Formatter].getName
  val bufferVar = ctx.freshName("sb")
  val bufferClass = classOf[StringBuffer].getName
  s"""
${patternGen.code}
boolean ${ev.isNull} = ${patternGen.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
$argEvalCode
$bufferClass $bufferVar = new $bufferClass();
$formatterClass $formatterVar = new $formatterClass($bufferVar, ${classOf[Locale].getName}.US);
$formatterVar.format(${patternGen.primitive}.toString() $formatArgs);
${ev.primitive} = UTF8String.fromString($bufferVar.toString());
}
"""
}

// Reports "printf" as the pretty name of this expression.
override def prettyName: String = "printf"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

// Checks StringFormat evaluation with literal arguments, bound references,
// a null format string and null arguments.
test("FORMAT") {
// Attribute references bound to ordinals 0, 1 and 2 of the input row.
val f = 'f.string.at(0)
val d1 = 'd.int.at(1)
val s1 = 's.int.at(2)

// row1 carries a format string in column 0; row2 carries null there instead.
val row1 = create_row("aa%d%s", 12, "cc")
val row2 = create_row(null, 12, "cc")
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
// A format string without any further arguments is returned as-is.
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")

checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
// A null format string is expected to produce a null result.
checkEvaluation(StringFormat(f, d1, s1), null, row2)
checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
// Null arguments are expected to be rendered as the text "null".
checkEvaluation(
StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
checkEvaluation(
StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
}

test("INSTR") {
Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1741,6 +1741,17 @@ object functions {
*/
def rtrim(e: Column): Column = StringTrimRight(e.expr)

/**
* Formats the given arguments in printf-style, using the value of the `format` column as the
* format string.
*
* @group string_funcs
* @since 1.5.0
*/
@scala.annotation.varargs
def formatString(format: Column, arguments: Column*): Column = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about a version in which format is a string, but the rest of the arguments are columns.

StringFormat((format +: arguments).map(_.expr): _*)
}

/**
* Format strings in printf-style.
* NOTE: `format` is the string value of the formatter, not column name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@ class StringFunctionsSuite extends QueryTest {
checkAnswer(
df.selectExpr("printf(a, b, c)"),
Row("aa123cc"))

val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")

checkAnswer(
df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
Row("aa123cc", "aa123cc"))

checkAnswer(
df2.selectExpr("printf(a, b, c)"),
Row("aa123cc"))
}

test("string instr function") {
Expand Down