Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
builder.append("]")
builder.build()
})
case udt: UserDefinedType[_] =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the only place we miss UDT in Cast?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about casting UDTs to other types?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I'll check

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can strip UDT at L608 and fix this problem entirely.

Copy link
Member Author

@maropu maropu Jan 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here?

protected override def nullSafeEval(input: Any): Any = cast(input)

Since we can only cast UDTs to other UDTs or to strings (per Cast.canCast), IIUC there is only one place to fix?

def canCast(from: DataType, to: DataType): Boolean = (from, to) match {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, that makes sense.

buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString))
case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
}

Expand Down Expand Up @@ -836,6 +838,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|$evPrim = $buffer.build();
""".stripMargin
}
case udt: UserDefinedType[_] =>
val udtRef = ctx.addReferenceObj("udt", udt)
(c, evPrim, evNull) => {
s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());"
}
case _ =>
(c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty}

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.functions._
Expand All @@ -44,6 +44,8 @@ object UDT {
case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data)
case _ => false
}

// Renders the elements of `data` as a parenthesized, comma-separated list, e.g. "(1.0, 3.0)".
override def toString: String = data.mkString("(", ", ", ")")
}

private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
Expand Down Expand Up @@ -143,7 +145,8 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType]
override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
}

class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest
with ExpressionEvalHelper {
import testImplicits._

private lazy val pointsRDD = Seq(
Expand Down Expand Up @@ -304,4 +307,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
pointsRDD.except(pointsRDD2),
Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
}

// Regression test for SPARK-23054: casting a UDT value to StringType should produce the
// user object's toString output.
test("SPARK-23054 Cast UserDefinedType to string") {
val udt = new UDT.MyDenseVectorUDT()
val vector = new UDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0))
// The literal carries the serialized form produced by the UDT, not the user object itself.
val data = udt.serialize(vector)
val ret = Cast(Literal(data, udt), StringType, None)
// The expected text is what MyDenseVector.toString yields for this vector.
checkEvaluation(ret, "(1.0, 3.0, 5.0, 7.0, 9.0)")
}
}