From ab7cf32e34acee2101ccccb39d42ec40282b7539 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 21 Dec 2015 18:21:34 +0800
Subject: [PATCH] Add UserDefinedType support to Cast.

---
 .../spark/sql/catalyst/expressions/Cast.scala      | 14 +++++
 .../sql/catalyst/expressions/CastSuite.scala       | 51 +++++++++++++++++++
 2 files changed, 65 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index b18f49f3203f..d82d3edae4e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.math.{BigDecimal => JavaBigDecimal}
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -81,6 +82,9 @@ object Cast {
                 toField.nullable)
         }
 
+    case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass =>
+      true
+
     case _ => false
   }
 
@@ -431,6 +435,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
     case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
     case map: MapType => castMap(from.asInstanceOf[MapType], map)
     case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
+    case udt: UserDefinedType[_]
+      if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+      identity[Any]
+    case _: UserDefinedType[_] =>
+      throw new SparkException(s"Cannot cast $from to $to.")
   }
 
   private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
@@ -473,6 +482,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
       castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
     case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
     case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
+    case udt: UserDefinedType[_]
+      if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+      (c, evPrim, evNull) => s"$evPrim = $c;"
+    case _: UserDefinedType[_] =>
+      throw new SparkException(s"Cannot cast $from to $to.")
   }
 
   // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index c99a4ac9645a..8483548159ec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
 import java.sql.{Timestamp, Date}
 import java.util.{TimeZone, Calendar}
 
+import org.scalatest.exceptions.TestFailedException
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
@@ -27,6 +29,34 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
+@SQLUserDefinedType(udt = classOf[ExampleUDT])
+class ExampleClass extends Serializable
+
+@SQLUserDefinedType(udt = classOf[Example2UDT])
+class ExampleClass2 extends Serializable
+
+class ExampleUDT extends UserDefinedType[ExampleClass] {
+  override def sqlType: DataType = DoubleType
+  override def pyUDT: String = "pyspark.sql.test.ExampleUDT"
+  override def serialize(obj: Any): Double = {
+    0.0
+  }
+  override def deserialize(datum: Any): ExampleClass = new ExampleClass
+  override def userClass: Class[ExampleClass] = classOf[ExampleClass]
+  private[spark] override def asNullable: ExampleUDT = this
+}
+
+class Example2UDT extends UserDefinedType[ExampleClass2] {
+  override def sqlType: DataType = DoubleType
+  override def pyUDT: String = "pyspark.sql.test.Example2UDT"
+  override def serialize(obj: Any): Double = {
+    0.0
+  }
+  override def deserialize(datum: Any): ExampleClass2 = new ExampleClass2
+  override def userClass: Class[ExampleClass2] = classOf[ExampleClass2]
+  private[spark] override def asNullable: Example2UDT = this
+}
+
 /**
  * Test suite for data type casting expression [[Cast]].
  */
@@ -790,4 +820,25 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(cast("abc", BooleanType), null)
     checkEvaluation(cast("", BooleanType), null)
   }
+
+  test("cast between UserDefinedTypes") {
+    assert(Cast.canCast(new ExampleUDT, new ExampleUDT) == true)
+    assert(Cast.canCast(new ExampleUDT, new Example2UDT) == false)
+
+    val udt = new ExampleClass
+
+    val castExpression = Cast(Literal.create(udt, new ExampleUDT), new ExampleUDT)
+    checkEvaluationWithoutCodegen(castExpression, udt)
+    val castExpression2 = Cast(Literal.create(0.0, new ExampleUDT), new ExampleUDT)
+    checkEvaluationWithGeneratedMutableProjection(castExpression2, 0.0)
+
+    val castExpression3 = Cast(Literal.create(udt, new ExampleUDT), new Example2UDT)
+    intercept[TestFailedException] {
+      checkEvaluationWithoutCodegen(castExpression3, udt)
+    }
+    val castExpression4 = Cast(Literal.create(0.0, new ExampleUDT), new Example2UDT)
+    intercept[TestFailedException] {
+      checkEvaluationWithGeneratedMutableProjection(castExpression4, 0.0)
+    }
+  }
 }
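
Not part of the patch itself: a minimal sketch of how the two new branches of Cast's private resolver behave when a resolved Cast is evaluated directly. It assumes the ExampleUDT/Example2UDT test classes added to CastSuite above and lives in the same package; the object name UdtCastSketch is invented for illustration only.

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkException

// Illustration only: exercises the new UDT branches added to Cast in this patch.
object UdtCastSketch {
  def main(args: Array[String]): Unit = {
    // Same userClass: the cast resolves to identity, so the serialized value
    // (0.0, per ExampleUDT.serialize) passes through unchanged.
    val sameClass = Cast(Literal.create(0.0, new ExampleUDT), new ExampleUDT)
    assert(sameClass.eval() == 0.0)

    // Different userClass: resolving the cast at evaluation time throws
    // SparkException("Cannot cast ... to ...").
    val differentClass = Cast(Literal.create(0.0, new ExampleUDT), new Example2UDT)
    try {
      differentClass.eval()
    } catch {
      case e: SparkException => println(e.getMessage)
    }
  }
}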