Skip to content

Commit 137d85f

Browse files
committed
Cast user-defined data into strings
1 parent 2250cb7 commit 137d85f

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
282282
builder.append("]")
283283
builder.build()
284284
})
285+
case udt: UserDefinedType[_] =>
286+
buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString))
285287
case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
286288
}
287289

@@ -836,6 +838,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
836838
|$evPrim = $buffer.build();
837839
""".stripMargin
838840
}
841+
case udt: UserDefinedType[_] =>
842+
val udtRef = ctx.addReferenceObj("udt", udt)
843+
(c, evPrim, evNull) => {
844+
s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());"
845+
}
839846
case _ =>
840847
(c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
841848
}

sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty}
2121

2222
import org.apache.spark.rdd.RDD
2323
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
24-
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
24+
import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal}
2525
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
2626
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
2727
import org.apache.spark.sql.functions._
@@ -44,6 +44,8 @@ object UDT {
4444
case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data)
4545
case _ => false
4646
}
47+
48+
override def toString: String = data.mkString("(", ", ", ")")
4749
}
4850

4951
private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
@@ -143,7 +145,8 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType]
143145
override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
144146
}
145147

146-
class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
148+
class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest
149+
with ExpressionEvalHelper {
147150
import testImplicits._
148151

149152
private lazy val pointsRDD = Seq(
@@ -304,4 +307,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
304307
pointsRDD.except(pointsRDD2),
305308
Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
306309
}
310+
311+
test("SPARK-23054 Cast UserDefinedType to string") {
312+
val udt = new UDT.MyDenseVectorUDT()
313+
val vector = new UDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0))
314+
val data = udt.serialize(vector)
315+
val ret = Cast(Literal(data, udt), StringType, None)
316+
checkEvaluation(ret, "(1.0, 3.0, 5.0, 7.0, 9.0)")
317+
}
307318
}

0 commit comments

Comments
 (0)