diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f2de4c8e30bec..f21aa1e9e3135 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -259,6 +259,29 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         builder.append("]")
         builder.build()
       })
+    case StructType(fields) =>
+      buildCast[InternalRow](_, row => {
+        val builder = new UTF8StringBuilder
+        builder.append("[")
+        if (row.numFields > 0) {
+          val st = fields.map(_.dataType)
+          val toUTF8StringFuncs = st.map(castToString)
+          if (!row.isNullAt(0)) {
+            builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String])
+          }
+          var i = 1
+          while (i < row.numFields) {
+            builder.append(",")
+            if (!row.isNullAt(i)) {
+              builder.append(" ")
+              builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String])
+            }
+            i += 1
+          }
+        }
+        builder.append("]")
+        builder.build()
+      })
     case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
   }
 
@@ -732,6 +755,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
      """.stripMargin
   }
 
+  private def writeStructToStringBuilder(
+      st: Seq[DataType],
+      row: String,
+      buffer: String,
+      ctx: CodegenContext): String = {
+    val structToStringCode = st.zipWithIndex.map { case (ft, i) =>
+      val fieldToStringCode = castToStringCode(ft, ctx)
+      val field = ctx.freshName("field")
+      val fieldStr = ctx.freshName("fieldStr")
+      s"""
+         |${if (i != 0) s"""$buffer.append(",");""" else ""}
+         |if (!$row.isNullAt($i)) {
+         |  ${if (i != 0) s"""$buffer.append(" ");""" else ""}
+         |
+         |  // Append $i field into the string buffer
+         |  ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")};
+         |  UTF8String $fieldStr = null;
+         |  ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)}
+         |  $buffer.append($fieldStr);
+         |}
+       """.stripMargin
+    }
+
+    val writeStructCode = ctx.splitExpressions(
+      expressions = structToStringCode,
+      funcName = "fieldToString",
+      arguments = ("InternalRow", row) :: (classOf[UTF8StringBuilder].getName, buffer) :: Nil)
+
+    s"""
+       |$buffer.append("[");
+       |$writeStructCode
+       |$buffer.append("]");
+     """.stripMargin
+  }
+
   private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = {
     from match {
       case BinaryType =>
@@ -765,6 +823,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
            |$evPrim = $buffer.build();
          """.stripMargin
         }
+      case StructType(fields) =>
+        (c, evPrim, evNull) => {
+          val row = ctx.freshName("row")
+          val buffer = ctx.freshName("buffer")
+          val bufferClass = classOf[UTF8StringBuilder].getName
+          val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx)
+          s"""
+             |InternalRow $row = $c;
+             |$bufferClass $buffer = new $bufferClass();
+             |$writeStructCode
+             |$evPrim = $buffer.build();
+           """.stripMargin
+        }
       case _ =>
         (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 1445bb8a97d40..5b25bdf907c3a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -906,4 +906,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
       StringType)
     checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]")
   }
+
+  test("SPARK-22981 Cast struct to string") {
+    val ret1 = cast(Literal.create((1, "a", 0.1)), StringType)
+    checkEvaluation(ret1, "[1, a, 0.1]")
+    val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType)
+    checkEvaluation(ret2, "[1,, a]")
+    val ret3 = cast(Literal.create(
+      (Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType)
+    checkEvaluation(ret3, "[2014-12-03, 2014-12-03 15:05:00]")
+    val ret4 = cast(Literal.create(((1, "a"), 5, 0.1)), StringType)
+    checkEvaluation(ret4, "[[1, a], 5, 0.1]")
+    val ret5 = cast(Literal.create((Seq(1, 2, 3), "a", 0.1)), StringType)
+    checkEvaluation(ret5, "[[1, 2, 3], a, 0.1]")
+    val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType)
+    checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]")
+  }
 }
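As a rough end-to-end illustration (not part of the patch), a standalone snippet along these lines could be used to observe the struct-to-string formatting that the new interpreted and codegen branches produce. The `SparkSession` setup, DataFrame, and column names here are illustrative assumptions; only the expected output shapes come from the CastSuite cases above.

```scala
// Illustrative sketch only: exercises CAST(struct AS STRING) through the
// public DataFrame API; names like StructCastDemo and demoDf are made up.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object StructCastDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("SPARK-22981 struct-to-string demo")
      .getOrCreate()
    import spark.implicits._

    // A struct column with an Int, a (possibly null) String, and a Double field.
    val demoDf = Seq((1, "a", 0.1), (2, null, 0.2))
      .toDF("i", "s", "d")
      .select(struct($"i", $"s", $"d").as("st"))

    // Per the CastSuite expectations above, the first row should render as
    // "[1, a, 0.1]", and the null field in the second row should collapse
    // between the commas, e.g. "[2,, 0.2]".
    demoDf.select($"st".cast("string").as("st_str")).show(truncate = false)

    spark.stop()
  }
}
```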