From e9dbd7b98cfbf3b6dc00feabcd02be9cee819c00 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Nov 2015 17:23:42 +0800 Subject: [PATCH 1/7] add type cast if the real type is different but compatible with encoder schema --- .../catalyst/encoders/ExpressionEncoder.scala | 58 ++++++++- .../expressions/complexTypeCreator.scala | 2 +- .../encoders/EncoderResolveSuite.scala | 121 ++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 7 +- 4 files changed, 180 insertions(+), 8 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 6eeba1442c1f..6cd3a64d6b2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,9 +28,10 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{StructField, ObjectType, StructType} +import org.apache.spark.sql.types.{DataType, StructField, ObjectType, StructType} /** * A factory for constructing encoders that convert objects and primitives to and from the @@ -210,26 +211,71 @@ case class ExpressionEncoder[T]( }) } + private def handleStruct(input: Expression, s: StructType): Expression = { + assert(input.isInstanceOf[NewInstance] || input.isInstanceOf[CreateExternalRow]) + val children = input.children + assert(children.length == s.length) + + val newChildren = children.zip(s.map(_.dataType)).map { + case (child, dt) => typeCast(child, dt) + } + + input.withNewChildren(newChildren) + } + + private def typeCast(input: Expression, expectedType: DataType): Expression = expectedType match { + case s: StructType => + var continue = true + input transformDown { + case c: CreateExternalRow if continue => + continue = false + handleStruct(c, s) + case n: NewInstance if continue => + continue = false + handleStruct(n, s) + } + + case _ => + var continue = true + input transformDown { + case u: UnresolvedExtractValue if continue => + continue = false + Cast(u, expectedType) + case g: GetInternalRowField if continue => + continue = false + Cast(g, expectedType) + case u: UnresolvedAttribute if continue => + continue = false + Cast(u, expectedType) + case a: AttributeReference if continue => + continue = false + Cast(a, expectedType) + } + } + /** * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the * given schema. 
*/ def resolve( - schema: Seq[Attribute], + attrs: Seq[Attribute], outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(schema) - val unbound = fromRowExpression transform { + val positionToAttribute = AttributeMap.toIndex(attrs) + val unbound = fromRowExpression transformUp { case b: BoundReference => positionToAttribute(b.ordinal) } - val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) + val withTypeCast = typeCast(unbound, if (flat) schema.head.dataType else schema) + + val plan = Project(Alias(withTypeCast, "")() :: Nil, LocalRelation(attrs)) val analyzedPlan = SimpleAnalyzer.execute(plan) + val optimizedPlan = SimplifyCasts(analyzedPlan) // In order to construct instances of inner classes (for example those declared in a REPL cell), // we need an instance of the outer scope. This rule substitues those outer objects into // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` // registry. - copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform { + copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform { case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => val outer = outerScopes.get(n.cls.getDeclaringClass.getName) if (outer == null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 1854dfaa7db3..72cc89c8be91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -126,7 +126,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { case class CreateNamedStruct(children: Seq[Expression]) extends Expression { /** - * Returns Aliased [[Expressions]] that could be used to construct a flattened version of this + * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this * StructType. */ def flatten: Seq[NamedExpression] = valExprs.zip(names).map { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala new file mode 100644 index 000000000000..b7e3446fc962 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ + +case class StringLongClass(a: String, b: Long) + +case class ComplexClass(a: Long, b: StringLongClass) + +class EncoderResolveSuite extends PlanTest { + test("real type doesn't match encoder schema but they are compatible: product") { + val encoder = ExpressionEncoder[StringLongClass] + val cls = classOf[StringLongClass] + + var attrs = Seq('a.string, 'b.int) + var fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + var expected: Expression = NewInstance( + cls, + toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil, + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + + attrs = Seq('a.int, 'b.long) + fromRowExpr = encoder.resolve(attrs, null).fromRowExpression + expected = NewInstance( + cls, + toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil, + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + + test("real type doesn't match encoder schema but they are compatible: nested product") { + val encoder = ExpressionEncoder[ComplexClass] + val innerCls = classOf[StringLongClass] + val cls = classOf[ComplexClass] + + val structType = new StructType().add("a", IntegerType).add("b", LongType) + val attrs = Seq('a.int, 'b.struct(structType)) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + cls, + Seq( + 'a.int.cast(LongType), + If( + 'b.struct(structType).isNull, + Literal.create(null, ObjectType(innerCls)), + NewInstance( + innerCls, + Seq( + toExternalString(GetStructField( + 'b.struct(structType), + structType(0), + 0).cast(StringType)), + GetStructField( + 'b.struct(structType), + structType(1), + 1)), + false, + ObjectType(innerCls)) + )), + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + + test("real type doesn't match encoder schema but they are compatible: tupled encoder") { + val encoder = ExpressionEncoder.tuple( + ExpressionEncoder[StringLongClass], + ExpressionEncoder[Long]) + val cls = classOf[StringLongClass] + + val structType = new StructType().add("a", StringType).add("b", ByteType, false) + val attrs = Seq('a.struct(structType), 'b.int) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + classOf[Tuple2[_, _]], + Seq( + NewInstance( + cls, + Seq( + toExternalString(GetStructField( + 'a.struct(structType), + structType(0), + 0)), + GetStructField( + 'a.struct(structType), + structType(1), + 1).cast(LongType)), + false, + ObjectType(cls)), + 'b.int.cast(LongType)), + false, + ObjectType(classOf[Tuple2[_, _]])) + compareExpressions(fromRowExpr, expected) + } + + private def toExternalString(e: Expression): Expression = { + Invoke(e, "toString", ObjectType(classOf[String]), Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 89d964aa3e46..322cb1ba05b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -384,7 +384,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq((JavaData(1), 1L), (JavaData(2), 1L))) } - ignore("Java encoder self join") { + test("Java 
encoder self join") { implicit val kryoEncoder = Encoders.javaSerialization[JavaData] val ds = Seq(JavaData(1), JavaData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == @@ -394,6 +394,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (JavaData(2), JavaData(1)), (JavaData(2), JavaData(2)))) } + + test("change encoder with compatible schema") { + val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData] + assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3))) + } } From 8d6a6ffd1a048f3941bb6e5f36e3f84755fc9760 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 21 Nov 2015 14:04:34 +0800 Subject: [PATCH 2/7] simplify the implmenetation --- .../spark/sql/catalyst/ScalaReflection.scala | 45 +++++++++++++--- .../catalyst/encoders/ExpressionEncoder.scala | 54 ++----------------- .../encoders/EncoderResolveSuite.scala | 2 +- 3 files changed, 43 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4a4a62ed1a46..aa2185909038 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -124,17 +124,46 @@ object ScalaReflection extends ScalaReflection { path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + def addToPath(part: String, dataType: DataType): Expression = { + val newPath = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + castToExpectedType(newPath, dataType) + } /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetInternalRowField(p, ordinal, dataType)) - .getOrElse(BoundReference(ordinal, dataType, false)) + def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = { + val newPath = path + .map(p => GetStructField(p, new StructField("", dataType), ordinal)) + .getOrElse(BoundReference(ordinal, dataType, false)) + castToExpectedType(newPath, dataType) + } /** Returns the current path or `BoundReference`. */ - def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + def getPath: Expression = { + val dataType = schemaFor(tpe).dataType + path.getOrElse(castToExpectedType(BoundReference(0, dataType, true), dataType)) + } + + /** + * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff + * and lost the required data type, which may lead to runtime error if the real type doesn't + * match the encoder's schema. + * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type + * is [a: int, b: long], then we will hit runtime error and say that we can't construct class + * `Data` with int and long, because we lost the information that `b` should be a string. + * + * This method help us "remember" the require data type by adding a `Cast`. Note that we don't + * need to add `Cast` for struct type because there must be `UnresolvedExtractValue` or + * `GetStructField` wrapping it. 
+ * + * TODO: this only works if the real type is compatible with the encoder's schema, we should + * also handle error cases. + */ + def castToExpectedType(expr: Expression, expected: DataType): Expression = expected match { + case _: StructType => expr + case _ => Cast(expr, expected) + } tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath @@ -302,7 +331,7 @@ object ScalaReflection extends ScalaReflection { if (cls.getName startsWith "scala.Tuple") { constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) } else { - constructorFor(fieldType, Some(addToPath(fieldName))) + constructorFor(fieldType, Some(addToPath(fieldName, dataType))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 6cd3a64d6b2d..3704c43de9ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{DataType, StructField, ObjectType, StructType} +import org.apache.spark.sql.types.{StructField, ObjectType, StructType} /** * A factory for constructing encoders that convert objects and primitives to and from the @@ -211,63 +211,19 @@ case class ExpressionEncoder[T]( }) } - private def handleStruct(input: Expression, s: StructType): Expression = { - assert(input.isInstanceOf[NewInstance] || input.isInstanceOf[CreateExternalRow]) - val children = input.children - assert(children.length == s.length) - - val newChildren = children.zip(s.map(_.dataType)).map { - case (child, dt) => typeCast(child, dt) - } - - input.withNewChildren(newChildren) - } - - private def typeCast(input: Expression, expectedType: DataType): Expression = expectedType match { - case s: StructType => - var continue = true - input transformDown { - case c: CreateExternalRow if continue => - continue = false - handleStruct(c, s) - case n: NewInstance if continue => - continue = false - handleStruct(n, s) - } - - case _ => - var continue = true - input transformDown { - case u: UnresolvedExtractValue if continue => - continue = false - Cast(u, expectedType) - case g: GetInternalRowField if continue => - continue = false - Cast(g, expectedType) - case u: UnresolvedAttribute if continue => - continue = false - Cast(u, expectedType) - case a: AttributeReference if continue => - continue = false - Cast(a, expectedType) - } - } - /** * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the * given schema. 
*/ def resolve( - attrs: Seq[Attribute], + schema: Seq[Attribute], outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(attrs) - val unbound = fromRowExpression transformUp { + val positionToAttribute = AttributeMap.toIndex(schema) + val unbound = fromRowExpression transform { case b: BoundReference => positionToAttribute(b.ordinal) } - val withTypeCast = typeCast(unbound, if (flat) schema.head.dataType else schema) - - val plan = Project(Alias(withTypeCast, "")() :: Nil, LocalRelation(attrs)) + val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) val optimizedPlan = SimplifyCasts(analyzedPlan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala index b7e3446fc962..10321df69262 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala @@ -90,7 +90,7 @@ class EncoderResolveSuite extends PlanTest { ExpressionEncoder[Long]) val cls = classOf[StringLongClass] - val structType = new StructType().add("a", StringType).add("b", ByteType, false) + val structType = new StructType().add("a", StringType).add("b", ByteType) val attrs = Seq('a.struct(structType), 'b.int) val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression val expected: Expression = NewInstance( From 7c5622392fe3f3789009b4b64534a8c310811b9b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2015 13:56:45 +0800 Subject: [PATCH 3/7] add UpCast --- .../spark/sql/catalyst/ScalaReflection.scala | 19 +++++------ .../sql/catalyst/analysis/Analyzer.scala | 34 ++++++++++++++++++- .../catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 9 +++-- .../plans/logical/basicOperators.scala | 23 +++++++------ .../apache/spark/sql/types/DecimalType.scala | 12 +++++++ .../encoders/EncoderResolveSuite.scala | 14 ++++++++ .../org/apache/spark/sql/GroupedDataset.scala | 4 ++- .../spark/sql/execution/basicOperators.scala | 16 +++++---- .../spark/sql/DatasetAggregatorSuite.scala | 4 +-- .../org/apache/spark/sql/DatasetSuite.scala | 16 ++++----- 11 files changed, 111 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index b8669f8a7ad8..4de7e30c7341 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -18,9 +18,8 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -128,7 +127,7 @@ object ScalaReflection extends ScalaReflection { val newPath = path .map(p => 
UnresolvedExtractValue(p, expressions.Literal(part))) .getOrElse(UnresolvedAttribute(part)) - castToExpectedType(newPath, dataType) + upCastToExpectedType(newPath, dataType) } /** Returns the current path with a field at ordinal extracted. */ @@ -136,13 +135,13 @@ object ScalaReflection extends ScalaReflection { val newPath = path .map(p => GetStructField(p, new StructField("", dataType), ordinal)) .getOrElse(BoundReference(ordinal, dataType, false)) - castToExpectedType(newPath, dataType) + upCastToExpectedType(newPath, dataType) } /** Returns the current path or `BoundReference`. */ def getPath: Expression = { val dataType = schemaFor(tpe).dataType - path.getOrElse(castToExpectedType(BoundReference(0, dataType, true), dataType)) + path.getOrElse(upCastToExpectedType(BoundReference(0, dataType, true), dataType)) } /** @@ -153,16 +152,16 @@ object ScalaReflection extends ScalaReflection { * is [a: int, b: long], then we will hit runtime error and say that we can't construct class * `Data` with int and long, because we lost the information that `b` should be a string. * - * This method help us "remember" the require data type by adding a `Cast`. Note that we don't - * need to add `Cast` for struct type because there must be `UnresolvedExtractValue` or - * `GetStructField` wrapping it. + * This method help us "remember" the require data type by adding a `UpCast`. Note that we + * don't need to cast struct type because there must be `UnresolvedExtractValue` or + * `GetStructField` wrapping it, and we will need to handle leaf type. * * TODO: this only works if the real type is compatible with the encoder's schema, we should * also handle error cases. */ - def castToExpectedType(expr: Expression, expected: DataType): Expression = expected match { + def upCastToExpectedType(expr: Expression, expected: DataType): Expression = expected match { case _: StructType => expr - case _ => Cast(expr, expected) + case _ => UpCast(expr, expected) } tpe match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 47962ebe6ef8..06e65a782d2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -88,7 +88,8 @@ class Analyzer( Batch("UDF", Once, HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, - CleanupAliases) + CleanupAliases, + RemoveUpCast) ) /** @@ -1169,3 +1170,34 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { } } } + +/** + * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. 
+ */ +object RemoveUpCast extends Rule[LogicalPlan] { + private def fail(from: DataType, to: DataType) = { + throw new AnalysisException( + s"Cannot up cast ${from.simpleString} to ${to.simpleString} as it may truncate") + } + + private def checkNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) + if (toPrecedence > 0 && fromPrecedence > toPrecedence) { + false + } else { + true + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transformAllExpressions { + case UpCast(child, dataType) => (child.dataType, dataType) match { + case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => fail(from, to) + case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => fail(from, to) + case (from, to) if !checkNumericPrecedence(from, to) => fail(from, to) + case _ => Cast(child, dataType) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index f90fc3cc1218..29502a59915f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -53,7 +53,7 @@ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: - private val numericPrecedence = + private[sql] val numericPrecedence = IndexedSeq( ByteType, ShortType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 533d17ea5c17..491cad08367f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -104,8 +104,7 @@ object Cast { } /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) - extends UnaryExpression with CodegenFallback { +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def toString: String = s"cast($child as ${dataType.simpleString})" @@ -915,3 +914,9 @@ case class Cast(child: Expression, dataType: DataType) """ } } + +/** + * Cast the child expression to the target data type, but will throw error if the cast might + * truncate, e.g. long -> int, timestamp -> data. + */ +case class UpCast(child: Expression, dataType: DataType) extends UnaryExpression with Unevaluable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 0c444482c5e4..c867db2fee6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -548,19 +548,22 @@ case class MapGroups[K, T, U]( /** Factory for constructing new `CoGroup` nodes. 
*/ object CoGroup { - def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], + def apply[Key, Left, Right, Result : Encoder]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], + leftEnc: ExpressionEncoder[Left], + rightEnc: ExpressionEncoder[Right], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], left: LogicalPlan, - right: LogicalPlan): CoGroup[K, Left, Right, R] = { + right: LogicalPlan): CoGroup[Key, Left, Right, Result] = { CoGroup( func, - encoderFor[K], - encoderFor[Left], - encoderFor[Right], - encoderFor[R], - encoderFor[R].schema.toAttributes, + keyEnc, + leftEnc, + rightEnc, + encoderFor[Result], + encoderFor[Result].schema.toAttributes, leftGroup, rightGroup, left, @@ -574,10 +577,10 @@ object CoGroup { */ case class CoGroup[K, Left, Right, R]( func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], - kEncoder: ExpressionEncoder[K], + keyEnc: ExpressionEncoder[K], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], - rEncoder: ExpressionEncoder[R], + resultEnc: ExpressionEncoder[R], output: Seq[Attribute], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 0cd352d0fa92..ce45245b9f6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -90,6 +90,18 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { case _ => false } + /** + * Returns whether this DecimalType is tighter than `other`. If yes, it means `this` + * can be casted into `other` safely without losing any precision or range. + */ + private[sql] def isTighterThan(other: DataType): Boolean = other match { + case dt: DecimalType => + (precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale + case dt: IntegralType => + isTighterThan(DecimalType.forType(dt)) + case _ => false + } + /** * The default size of a value of the DecimalType is 4096 bytes. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala index 10321df69262..f130358e139f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.encoders +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -24,6 +25,8 @@ import org.apache.spark.sql.types._ case class StringLongClass(a: String, b: Long) +case class StringIntClass(a: String, b: Int) + case class ComplexClass(a: Long, b: StringLongClass) class EncoderResolveSuite extends PlanTest { @@ -118,4 +121,15 @@ class EncoderResolveSuite extends PlanTest { private def toExternalString(e: Expression): Expression = { Invoke(e, "toString", ObjectType(classOf[String]), Nil) } + + test("throw exception if real type is not compatible with encoder schema") { + intercept[AnalysisException] { + ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) + } + + intercept[AnalysisException] { + val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) + ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 7f43ce16901b..c09e383f0377 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -304,11 +304,13 @@ class GroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.unresolvedTEncoder new Dataset[R]( sqlContext, CoGroup( f, + resolvedKEncoder, + this.resolvedTEncoder, + other.resolvedTEncoder, this.groupingAttributes, other.groupingAttributes, this.logicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e79092efdaa3..7fdfe67309b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -372,10 +372,10 @@ case class MapGroups[K, T, U]( */ case class CoGroup[K, Left, Right, R]( func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], - kEncoder: ExpressionEncoder[K], + keyEnc: ExpressionEncoder[K], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], - rEncoder: ExpressionEncoder[R], + resultEnc: ExpressionEncoder[R], output: Seq[Attribute], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], @@ -392,15 +392,17 @@ case class CoGroup[K, Left, Right, R]( left.execute().zipPartitions(right.execute()) { (leftData, rightData) => val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val groupKeyEncoder = kEncoder.bind(leftGroup) + val boundKeyEnc = keyEnc.bind(leftGroup) + val boundLeftEnc = leftEnc.bind(left.output) + 
val boundRightEnc = rightEnc.bind(right.output) new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { case (key, leftResult, rightResult) => val result = func( - groupKeyEncoder.fromRow(key), - leftResult.map(leftEnc.fromRow), - rightResult.map(rightEnc.fromRow)) - result.map(rEncoder.toRow) + boundKeyEnc.fromRow(key), + leftResult.map(boundLeftEnc.fromRow), + rightResult.map(boundRightEnc.fromRow)) + result.map(resultEnc.toRow) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 19dce5d1e2f3..c6d2bf07b280 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -131,9 +131,9 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { checkAnswer( ds.groupBy(_._1).agg( sum(_._2), - expr("sum(_2)").as[Int], + expr("sum(_2)").as[Long], count("*")), - ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) + ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L)) } test("typed aggregation: complex case") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index ea7dcbf300a6..bf879a55aa97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -296,24 +296,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Int]), - ("a", 30), ("b", 3), ("c", 1)) + ds.groupBy(_._1).agg(sum("_2").as[Long]), + ("a", 30L), ("b", 3L), ("c", 1L)) } test("typed aggregation: expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]), - ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L)) + ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), + ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L)) } test("typed aggregation: expr, expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long]), - ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L)) + ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), + ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L)) } test("typed aggregation: expr, expr, expr, expr") { @@ -321,11 +321,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer( ds.groupBy(_._1).agg( - sum("_2").as[Int], + sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*").as[Long], avg("_2").as[Double]), - ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0)) + ("a", 30L, 32L, 2L, 15.0), ("b", 3L, 5L, 2L, 1.5), ("c", 1L, 2L, 1L, 1.0)) } test("cogroup") { From 6c9dc1e22fb88229247cf1bb284c06715b089da3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 25 Nov 2015 13:10:41 +0800 Subject: [PATCH 4/7] address comments --- .../spark/sql/catalyst/ScalaReflection.scala | 7 +- .../sql/catalyst/analysis/Analyzer.scala | 25 ++--- ...ite.scala => EncoderResolutionSuite.scala} | 102 +++++++++++------- 3 files changed, 79 insertions(+), 55 deletions(-) rename 
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/{EncoderResolveSuite.scala => EncoderResolutionSuite.scala} (58%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 08bb0004d532..ed0e7d5c7efe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -152,12 +152,9 @@ object ScalaReflection extends ScalaReflection { * is [a: int, b: long], then we will hit runtime error and say that we can't construct class * `Data` with int and long, because we lost the information that `b` should be a string. * - * This method help us "remember" the require data type by adding a `UpCast`. Note that we + * This method help us "remember" the required data type by adding a `UpCast`. Note that we * don't need to cast struct type because there must be `UnresolvedExtractValue` or - * `GetStructField` wrapping it, and we will need to handle leaf type. - * - * TODO: this only works if the real type is compatible with the encoder's schema, we should - * also handle error cases. + * `GetStructField` wrapping it, thus we only need to handle leaf type. */ def upCastToExpectedType(expr: Expression, expected: DataType): Expression = expected match { case _: StructType => expr diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 06e65a782d2a..55887b8df5da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -89,7 +89,7 @@ class Analyzer( HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, CleanupAliases, - RemoveUpCast) + ResolveUpCast) ) /** @@ -1174,28 +1174,25 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { /** * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. 
*/ -object RemoveUpCast extends Rule[LogicalPlan] { - private def fail(from: DataType, to: DataType) = { - throw new AnalysisException( - s"Cannot up cast ${from.simpleString} to ${to.simpleString} as it may truncate") +object ResolveUpCast extends Rule[LogicalPlan] { + private def fail(from: Expression, to: DataType) = { + throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " + + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate") } - private def checkNumericPrecedence(from: DataType, to: DataType): Boolean = { + private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) - if (toPrecedence > 0 && fromPrecedence > toPrecedence) { - false - } else { - true - } + toPrecedence > 0 && fromPrecedence > toPrecedence } def apply(plan: LogicalPlan): LogicalPlan = { plan transformAllExpressions { case UpCast(child, dataType) => (child.dataType, dataType) match { - case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => fail(from, to) - case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => fail(from, to) - case (from, to) if !checkNumericPrecedence(from, to) => fail(from, to) + case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => fail(child, to) + case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => fail(child, to) + case (from, to) if illegalNumericPrecedence(from, to) => fail(child, to) + case (TimestampType, DateType) => fail(child, DateType) case _ => Cast(child, dataType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala similarity index 58% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index f130358e139f..27bbfbc33f7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolveSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.encoders +import scala.reflect.runtime.universe.TypeTag + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -29,28 +31,32 @@ case class StringIntClass(a: String, b: Int) case class ComplexClass(a: Long, b: StringLongClass) -class EncoderResolveSuite extends PlanTest { +class EncoderResolutionSuite extends PlanTest { test("real type doesn't match encoder schema but they are compatible: product") { val encoder = ExpressionEncoder[StringLongClass] val cls = classOf[StringLongClass] - var attrs = Seq('a.string, 'b.int) - var fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression - var expected: Expression = NewInstance( - cls, - toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil, - false, - ObjectType(cls)) - compareExpressions(fromRowExpr, expected) + { + val attrs = Seq('a.string, 'b.int) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + cls, + toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil, + false, + ObjectType(cls)) + 
compareExpressions(fromRowExpr, expected) + } - attrs = Seq('a.int, 'b.long) - fromRowExpr = encoder.resolve(attrs, null).fromRowExpression - expected = NewInstance( - cls, - toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil, - false, - ObjectType(cls)) - compareExpressions(fromRowExpr, expected) + { + val attrs = Seq('a.int, 'b.long) + val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression + val expected = NewInstance( + cls, + toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil, + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } } test("real type doesn't match encoder schema but they are compatible: nested product") { @@ -71,14 +77,9 @@ class EncoderResolveSuite extends PlanTest { NewInstance( innerCls, Seq( - toExternalString(GetStructField( - 'b.struct(structType), - structType(0), - 0).cast(StringType)), - GetStructField( - 'b.struct(structType), - structType(1), - 1)), + toExternalString( + GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)), + GetStructField('b.struct(structType), 1, Some("b"))), false, ObjectType(innerCls)) )), @@ -102,14 +103,8 @@ class EncoderResolveSuite extends PlanTest { NewInstance( cls, Seq( - toExternalString(GetStructField( - 'a.struct(structType), - structType(0), - 0)), - GetStructField( - 'a.struct(structType), - structType(1), - 1).cast(LongType)), + toExternalString(GetStructField('a.struct(structType), 0, Some("a"))), + GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)), false, ObjectType(cls)), 'b.int.cast(LongType)), @@ -123,13 +118,48 @@ class EncoderResolveSuite extends PlanTest { } test("throw exception if real type is not compatible with encoder schema") { - intercept[AnalysisException] { + val msg1 = intercept[AnalysisException] { ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) - } + }.message + assert(msg1.contains("Cannot up cast `b` from bigint to int as it may truncate")) - intercept[AnalysisException] { + val msg2 = intercept[AnalysisException] { val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null) + }.message + assert(msg2.contains("Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate")) + + } + + // test for leaf types + castSuccess[Int, Long] + castSuccess[java.sql.Date, java.sql.Timestamp] + castSuccess[Long, String] + castSuccess[String, Long] + castSuccess[Int, java.math.BigDecimal] + castSuccess[Long, java.math.BigDecimal] + + castFail[Long, Int] + castFail[java.sql.Timestamp, java.sql.Date] + castFail[java.math.BigDecimal, Double] + castFail[Double, java.math.BigDecimal] + castFail[java.math.BigDecimal, Int] + + private def castSuccess[T: TypeTag, U: TypeTag]: Unit = { + val from = ExpressionEncoder[T] + val to = ExpressionEncoder[U] + val catalystType = from.schema.head.dataType.simpleString + test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") { + to.resolve(from.schema.toAttributes, null) + } + } + + private def castFail[T: TypeTag, U: TypeTag]: Unit = { + val from = ExpressionEncoder[T] + val to = ExpressionEncoder[U] + val catalystType = from.schema.head.dataType.simpleString + test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") { + intercept[AnalysisException](to.resolve(from.schema.toAttributes, null)) } } } From 399d812b4970fe3c20ad1f93a33120e859bc8985 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 26 Nov 2015 
18:39:10 +0800 Subject: [PATCH 5/7] improves --- .../spark/sql/catalyst/ScalaReflection.scala | 61 +++++++++++++------ .../sql/catalyst/analysis/Analyzer.scala | 21 ++++--- .../spark/sql/catalyst/expressions/Cast.scala | 3 +- .../encoders/EncoderResolutionSuite.scala | 20 +++++- 4 files changed, 77 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index ed0e7d5c7efe..dc0b999a7642 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -116,32 +116,42 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ - def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None) + def constructorFor[T : TypeTag]: Expression = { + val tpe = localTypeOf[T] + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + constructorFor(tpe, None, walkedTypePath) + } private def constructorFor( tpe: `Type`, - path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { + path: Option[Expression], + walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String, dataType: DataType): Expression = { + def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { val newPath = path .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) .getOrElse(UnresolvedAttribute(part)) - upCastToExpectedType(newPath, dataType) + upCastToExpectedType(newPath, dataType, walkedTypePath) } /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = { + def addToPathOrdinal(ordinal: Int, dataType: DataType, walkedTypePath: Seq[String]): Expression = { val newPath = path .map(p => GetStructField(p, ordinal)) .getOrElse(BoundReference(ordinal, dataType, false)) - upCastToExpectedType(newPath, dataType) + upCastToExpectedType(newPath, dataType, walkedTypePath) } /** Returns the current path or `BoundReference`. */ def getPath: Expression = { val dataType = schemaFor(tpe).dataType - path.getOrElse(upCastToExpectedType(BoundReference(0, dataType, true), dataType)) + if (path.isDefined) { + path.get + } else { + upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath) + } } /** @@ -156,9 +166,12 @@ object ScalaReflection extends ScalaReflection { * don't need to cast struct type because there must be `UnresolvedExtractValue` or * `GetStructField` wrapping it, thus we only need to handle leaf type. 
*/ - def upCastToExpectedType(expr: Expression, expected: DataType): Expression = expected match { + def upCastToExpectedType( + expr: Expression, + expected: DataType, + walkedTypePath: Seq[String]): Expression = expected match { case _: StructType => expr - case _ => UpCast(expr, expected) + case _ => UpCast(expr, expected, walkedTypePath) } tpe match { @@ -166,7 +179,9 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - WrapOption(constructorFor(optType, path)) + val className = getClassNameFromType(optType) + val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath + WrapOption(constructorFor(optType, path, newTypePath)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] @@ -244,9 +259,11 @@ object ScalaReflection extends ScalaReflection { primitiveMethod.map { method => Invoke(getPath, method, arrayClassFor(elementType)) }.getOrElse { + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath Invoke( MapObjects( - p => constructorFor(elementType, Some(p)), + p => constructorFor(elementType, Some(p), newTypePath), getPath, schemaFor(elementType).dataType), "array", @@ -255,10 +272,12 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath val arrayData = Invoke( MapObjects( - p => constructorFor(elementType, Some(p)), + p => constructorFor(elementType, Some(p), newTypePath), getPath, schemaFor(elementType).dataType), "array", @@ -271,12 +290,13 @@ object ScalaReflection extends ScalaReflection { arrayData :: Nil) case t if t <:< localTypeOf[Map[_, _]] => + // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t val keyData = Invoke( MapObjects( - p => constructorFor(keyType, Some(p)), + p => constructorFor(keyType, Some(p), walkedTypePath), Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), schemaFor(keyType).dataType), "array", @@ -285,7 +305,7 @@ object ScalaReflection extends ScalaReflection { val valueData = Invoke( MapObjects( - p => constructorFor(valueType, Some(p)), + p => constructorFor(valueType, Some(p), walkedTypePath), Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), schemaFor(valueType).dataType), "array", @@ -322,12 +342,19 @@ object ScalaReflection extends ScalaReflection { val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) val dataType = schemaFor(fieldType).dataType - + val clsName = getClassNameFromType(fieldType) + val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. 
if (cls.getName startsWith "scala.Tuple") { - constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) + constructorFor( + fieldType, + Some(addToPathOrdinal(i, dataType, newTypePath)), + newTypePath) } else { - constructorFor(fieldType, Some(addToPath(fieldName, dataType))) + constructorFor( + fieldType, + Some(addToPath(fieldName, dataType, newTypePath)), + newTypePath) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 55887b8df5da..eab87e575c7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1175,9 +1175,12 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. */ object ResolveUpCast extends Rule[LogicalPlan] { - private def fail(from: Expression, to: DataType) = { + private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " + - s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate") + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + + "You can either add an explicit cast to the input data or choose a higher precision " + + "type of the field in the target object") } private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { @@ -1188,11 +1191,15 @@ object ResolveUpCast extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { plan transformAllExpressions { - case UpCast(child, dataType) => (child.dataType, dataType) match { - case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => fail(child, to) - case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => fail(child, to) - case (from, to) if illegalNumericPrecedence(from, to) => fail(child, to) - case (TimestampType, DateType) => fail(child, DateType) + case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match { + case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => + fail(child, to, walkedTypePath) + case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => + fail(child, to, walkedTypePath) + case (from, to) if illegalNumericPrecedence(from, to) => + fail(child, to, walkedTypePath) + case (TimestampType, DateType) => + fail(child, DateType, walkedTypePath) case _ => Cast(child, dataType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 491cad08367f..9ce367cdcd90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -919,4 +919,5 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { * Cast the child expression to the target data type, but will throw error if the cast might * truncate, e.g. long -> int, timestamp -> data. 
*/ -case class UpCast(child: Expression, dataType: DataType) extends UnaryExpression with Unevaluable +case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String]) + extends UnaryExpression with Unevaluable diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 27bbfbc33f7c..5b502485fb77 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -121,14 +121,28 @@ class EncoderResolutionSuite extends PlanTest { val msg1 = intercept[AnalysisException] { ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) }.message - assert(msg1.contains("Cannot up cast `b` from bigint to int as it may truncate")) + assert(msg1 == + s""" + |Cannot up cast `b` from bigint to int as it may truncate + |The type path of the target object is: + |- field (class: "scala.Int", name: "b") + |- root class: "org.apache.spark.sql.catalyst.encoders.StringIntClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") val msg2 = intercept[AnalysisException] { val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null) }.message - assert(msg2.contains("Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate")) - + assert(msg2 == + s""" + |Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate + |The type path of the target object is: + |- field (class: "scala.Long", name: "b") + |- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b") + |- root class: "org.apache.spark.sql.catalyst.encoders.ComplexClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") } // test for leaf types From 2f7370c33ddda84e306a73d478c6cf470e04837f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 27 Nov 2015 11:35:11 +0800 Subject: [PATCH 6/7] update --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 5 ++++- .../main/scala/org/apache/spark/sql/types/DecimalType.scala | 4 ++-- .../spark/sql/catalyst/encoders/EncoderResolutionSuite.scala | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index dc0b999a7642..9b6b5b8bd1a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -137,7 +137,10 @@ object ScalaReflection extends ScalaReflection { } /** Returns the current path with a field at ordinal extracted. 
From 2f7370c33ddda84e306a73d478c6cf470e04837f Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 27 Nov 2015 11:35:11 +0800
Subject: [PATCH 6/7] update

---
 .../org/apache/spark/sql/catalyst/ScalaReflection.scala      | 5 ++++-
 .../main/scala/org/apache/spark/sql/types/DecimalType.scala  | 4 ++--
 .../spark/sql/catalyst/encoders/EncoderResolutionSuite.scala | 2 +-
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index dc0b999a7642..9b6b5b8bd1a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -137,7 +137,10 @@ object ScalaReflection extends ScalaReflection {
   }

   /** Returns the current path with a field at ordinal extracted. */
-  def addToPathOrdinal(ordinal: Int, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
+  def addToPathOrdinal(
+      ordinal: Int,
+      dataType: DataType,
+      walkedTypePath: Seq[String]): Expression = {
     val newPath = path
       .map(p => GetStructField(p, ordinal))
       .getOrElse(BoundReference(ordinal, dataType, false))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index ce45245b9f6d..3a845e9865ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -85,7 +85,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
   private[sql] def isWiderThan(other: DataType): Boolean = other match {
     case dt: DecimalType =>
       (precision - scale) >= (dt.precision - dt.scale) && scale >= dt.scale
-    case dt: IntegralType =>
+    case dt: NumericType =>
       isWiderThan(DecimalType.forType(dt))
     case _ => false
   }
@@ -97,7 +97,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
   private[sql] def isTighterThan(other: DataType): Boolean = other match {
     case dt: DecimalType =>
       (precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale
-    case dt: IntegralType =>
+    case dt: NumericType =>
       isTighterThan(DecimalType.forType(dt))
     case _ => false
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 5b502485fb77..9d2aabc0992e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -152,11 +152,11 @@ class EncoderResolutionSuite extends PlanTest {
   castSuccess[String, Long]
   castSuccess[Int, java.math.BigDecimal]
   castSuccess[Long, java.math.BigDecimal]
+  castSuccess[Double, java.math.BigDecimal]

   castFail[Long, Int]
   castFail[java.sql.Timestamp, java.sql.Date]
   castFail[java.math.BigDecimal, Double]
-  castFail[Double, java.math.BigDecimal]
   castFail[java.math.BigDecimal, Int]

   private def castSuccess[T: TypeTag, U: TypeTag]: Unit = {
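Note: relaxing isWiderThan from IntegralType to NumericType is why Double to java.math.BigDecimal moves into the castSuccess list here; the check is pure precision/scale arithmetic once the source type is mapped to an equivalent decimal. A rough standalone sketch of that comparison (a toy Dec class with assumed precision/scale mappings, not Spark's DecimalType):

// Toy stand-in for the precision/scale comparison; the mappings for Long and
// Double below are assumptions for illustration, not Spark's exact values.
case class Dec(precision: Int, scale: Int) {
  def isWiderThan(other: Dec): Boolean =
    (precision - scale) >= (other.precision - other.scale) && scale >= other.scale
}

object DecSketch {
  def main(args: Array[String]): Unit = {
    val systemDefault = Dec(38, 18)
    val longAsDec = Dec(20, 0)     // a 64-bit long needs up to 20 digits
    val doubleAsDec = Dec(30, 15)  // assumed mapping for a double
    println(systemDefault.isWiderThan(longAsDec))    // true
    println(systemDefault.isWiderThan(doubleAsDec))  // true
  }
}

The follow-up commit below reverts this to IntegralType and moves Double to BigDecimal back into the failing cases.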
From 57b0d7e5df19d7eda1fcc3ac54f6d3f96890d366 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 1 Dec 2015 10:15:27 +0800
Subject: [PATCH 7/7] address comments

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 8 ++++++--
 .../org/apache/spark/sql/catalyst/expressions/Cast.scala  | 4 +++-
 .../scala/org/apache/spark/sql/types/DecimalType.scala    | 4 ++--
 .../sql/catalyst/encoders/EncoderResolutionSuite.scala    | 5 +++--
 4 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index eab87e575c7b..a46942b8bcfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -72,6 +72,7 @@ class Analyzer(
       ResolveReferences ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
+      ResolveUpCast ::
       ResolveSortReferences ::
       ResolveGenerate ::
       ResolveFunctions ::
@@ -88,8 +89,7 @@ class Analyzer(
     Batch("UDF", Once,
       HandleNullInputsForUDF),
     Batch("Cleanup", fixedPoint,
-      CleanupAliases,
-      ResolveUpCast)
+      CleanupAliases)
   )

   /**
@@ -1191,6 +1191,8 @@ object ResolveUpCast extends Rule[LogicalPlan] {

   def apply(plan: LogicalPlan): LogicalPlan = {
     plan transformAllExpressions {
+      case u @ UpCast(child, _, _) if !child.resolved => u
+
       case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
         case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
           fail(child, to, walkedTypePath)
@@ -1200,6 +1202,8 @@ object ResolveUpCast extends Rule[LogicalPlan] {
           fail(child, to, walkedTypePath)
         case (TimestampType, DateType) =>
           fail(child, DateType, walkedTypePath)
+        case (StringType, to: NumericType) =>
+          fail(child, to, walkedTypePath)
         case _ => Cast(child, dataType)
       }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 9ce367cdcd90..cb60d5958d53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -920,4 +920,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
  * truncate, e.g. long -> int, timestamp -> data.
  */
 case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String])
-  extends UnaryExpression with Unevaluable
+  extends UnaryExpression with Unevaluable {
+  override lazy val resolved = false
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 3a845e9865ae..ce45245b9f6d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -85,7 +85,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
   private[sql] def isWiderThan(other: DataType): Boolean = other match {
     case dt: DecimalType =>
       (precision - scale) >= (dt.precision - dt.scale) && scale >= dt.scale
-    case dt: NumericType =>
+    case dt: IntegralType =>
       isWiderThan(DecimalType.forType(dt))
     case _ => false
   }
@@ -97,7 +97,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
   private[sql] def isTighterThan(other: DataType): Boolean = other match {
     case dt: DecimalType =>
       (precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale
-    case dt: NumericType =>
+    case dt: IntegralType =>
       isTighterThan(DecimalType.forType(dt))
     case _ => false
   }
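Note: overriding resolved to false is what forces every UpCast to be rewritten (or rejected) during analysis: a plan that still contains one can never be considered resolved, which fits with ResolveUpCast moving from the Cleanup batch into the main resolution batch above. A toy sketch of the pattern, using simplified placeholder types rather than Catalyst's Expression hierarchy:

// Toy illustration of the "stays unresolved until a rule rewrites it" idea;
// the types here are simplified placeholders, not Catalyst classes.
sealed trait Expr { def resolved: Boolean = true }
case class Literal(value: Any) extends Expr
case class UpCastPlaceholder(child: Expr) extends Expr {
  // Marking the node unresolved means analysis cannot finish while it remains.
  override def resolved: Boolean = false
}

object ResolvedSketch {
  // Stands in for the analyzer's overall resolution check.
  def allResolved(exprs: Seq[Expr]): Boolean = exprs.forall(_.resolved)

  def main(args: Array[String]): Unit = {
    println(allResolved(Seq(Literal(1), UpCastPlaceholder(Literal(1)))))  // false
    println(allResolved(Seq(Literal(1))))                                 // true
  }
}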
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 9d2aabc0992e..0289988342e7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -149,15 +149,16 @@ class EncoderResolutionSuite extends PlanTest {
   castSuccess[Int, Long]
   castSuccess[java.sql.Date, java.sql.Timestamp]
   castSuccess[Long, String]
-  castSuccess[String, Long]
   castSuccess[Int, java.math.BigDecimal]
   castSuccess[Long, java.math.BigDecimal]
-  castSuccess[Double, java.math.BigDecimal]

   castFail[Long, Int]
   castFail[java.sql.Timestamp, java.sql.Date]
   castFail[java.math.BigDecimal, Double]
+  castFail[Double, java.math.BigDecimal]
   castFail[java.math.BigDecimal, Int]
+  castFail[String, Long]
+

   private def castSuccess[T: TypeTag, U: TypeTag]: Unit = {
     val from = ExpressionEncoder[T]