@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.math.{BigDecimal => JavaBigDecimal}

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -81,6 +82,9 @@ object Cast {
toField.nullable)
}

    case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass =>
      true

    case _ => false
  }

@@ -431,6 +435,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
    case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
    case map: MapType => castMap(from.asInstanceOf[MapType], map)
    case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
    case udt: UserDefinedType[_]
      if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
      identity[Any]
    case _: UserDefinedType[_] =>
      throw new SparkException(s"Cannot cast $from to $to.")
  }

  private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
@@ -473,6 +482,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
      castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
    case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
    case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
    case udt: UserDefinedType[_]
      if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
      (c, evPrim, evNull) => s"$evPrim = $c;"
    case _: UserDefinedType[_] =>
      throw new SparkException(s"Cannot cast $from to $to.")
  }

// Since we need to cast child expressions recursively inside ComplexTypes, such as Map's
@@ -20,13 +20,43 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Timestamp, Date}
import java.util.{TimeZone, Calendar}

import org.scalatest.exceptions.TestFailedException

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@SQLUserDefinedType(udt = classOf[ExampleUDT])
class ExampleClass extends Serializable

@SQLUserDefinedType(udt = classOf[Example2UDT])
class ExampleClass2 extends Serializable

class ExampleUDT extends UserDefinedType[ExampleClass] {
  override def sqlType: DataType = DoubleType
  override def pyUDT: String = "pyspark.sql.test.ExampleUDT"
  override def serialize(obj: Any): Double = {
    0.0
  }
  override def deserialize(datum: Any): ExampleClass = new ExampleClass
  override def userClass: Class[ExampleClass] = classOf[ExampleClass]
  private[spark] override def asNullable: ExampleUDT = this
}

class Example2UDT extends UserDefinedType[ExampleClass2] {
  override def sqlType: DataType = DoubleType
  override def pyUDT: String = "pyspark.sql.test.Example2UDT"
  override def serialize(obj: Any): Double = {
    0.0
  }
  override def deserialize(datum: Any): ExampleClass2 = new ExampleClass2
  override def userClass: Class[ExampleClass2] = classOf[ExampleClass2]
  private[spark] override def asNullable: Example2UDT = this
}

/**
 * Test suite for data type casting expression [[Cast]].
 */
@@ -790,4 +820,25 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast("abc", BooleanType), null)
checkEvaluation(cast("", BooleanType), null)
}

test("cast between UserDefinedTypes") {
assert(Cast.canCast(new ExampleUDT, new ExampleUDT) == true)
assert(Cast.canCast(new ExampleUDT, new Example2UDT) == false)

val udt = new ExampleClass

val castExpression = Cast(Literal.create(udt, new ExampleUDT), new ExampleUDT)
checkEvaluationWithoutCodegen(castExpression, udt)
val castExpression2 = Cast(Literal.create(0.0, new ExampleUDT), new ExampleUDT)
checkEvaluationWithGeneratedMutableProjection(castExpression2, 0.0)
Comment from the PR author: For the codegen case, the internal data for ExampleUDT here is a double. Because of that, we cannot simply call checkEvaluation in this test. (A short sketch expanding on this follows after the suite below.)


    val castExpression3 = Cast(Literal.create(udt, new ExampleUDT), new Example2UDT)
    intercept[TestFailedException] {
      checkEvaluationWithoutCodegen(castExpression3, udt)
    }
    val castExpression4 = Cast(Literal.create(0.0, new ExampleUDT), new Example2UDT)
    intercept[TestFailedException] {
      checkEvaluationWithGeneratedMutableProjection(castExpression4, 0.0)
    }
  }
}
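
To expand on the author's comment above: ExampleUDT.sqlType is DoubleType, so the catalyst-internal value for an ExampleUDT column is whatever serialize produces (a Double), while deserialize turns that value back into the user-facing ExampleClass. The sketch below is not part of the PR; it only reuses the ExampleUDT/ExampleClass definitions from the test file, and the wrapper object name is made up for illustration.

// Sketch only (not part of the PR): the UDT round trip behind the
// checkEvaluationWithGeneratedMutableProjection expectations above.
// Assumes the ExampleUDT / ExampleClass definitions from CastSuite are in scope.
object ExampleUDTRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val udt = new ExampleUDT

    // serialize() maps the user object to the sqlType (DoubleType) value; this
    // Double is what generated code carries and what the identity cast copies,
    // hence 0.0 as the expected value in the codegen check.
    val internal: Double = udt.serialize(new ExampleClass)
    assert(internal == 0.0)

    // deserialize() maps the internal Double back to the user class, which is
    // why the interpreted (no-codegen) check can expect an ExampleClass.
    val restored: ExampleClass = udt.deserialize(internal)
    assert(restored != null)
  }
}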