-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-11856][SQL] add type cast if the real type is different but compatible with encoder schema #9840
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-11856][SQL] add type cast if the real type is different but compatible with encoder schema #9840
Changes from 5 commits
e9dbd7b
19dbed2
8d6a6ff
e5d963b
7c56223
211a107
6c9dc1e
399d812
2f7370c
57b0d7e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,9 +18,8 @@ | |
| package org.apache.spark.sql.catalyst | ||
|
|
||
| import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} | ||
| import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils} | ||
| import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.plans.logical.LocalRelation | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.unsafe.types.UTF8String | ||
| import org.apache.spark.util.Utils | ||
|
|
@@ -124,17 +123,46 @@ object ScalaReflection extends ScalaReflection { | |
| path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { | ||
|
|
||
| /** Returns the current path with a sub-field extracted. */ | ||
| def addToPath(part: String): Expression = path | ||
| .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) | ||
| .getOrElse(UnresolvedAttribute(part)) | ||
| def addToPath(part: String, dataType: DataType): Expression = { | ||
| val newPath = path | ||
| .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) | ||
| .getOrElse(UnresolvedAttribute(part)) | ||
| upCastToExpectedType(newPath, dataType) | ||
| } | ||
|
|
||
| /** Returns the current path with a field at ordinal extracted. */ | ||
| def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path | ||
| .map(p => GetInternalRowField(p, ordinal, dataType)) | ||
| .getOrElse(BoundReference(ordinal, dataType, false)) | ||
| def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = { | ||
| val newPath = path | ||
| .map(p => GetStructField(p, new StructField("", dataType), ordinal)) | ||
| .getOrElse(BoundReference(ordinal, dataType, false)) | ||
| upCastToExpectedType(newPath, dataType) | ||
| } | ||
|
|
||
| /** Returns the current path or `BoundReference`. */ | ||
| def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) | ||
| def getPath: Expression = { | ||
| val dataType = schemaFor(tpe).dataType | ||
| path.getOrElse(upCastToExpectedType(BoundReference(0, dataType, true), dataType)) | ||
| } | ||
|
|
||
| /** | ||
| * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff | ||
| * and lost the required data type, which may lead to runtime error if the real type doesn't | ||
| * match the encoder's schema. | ||
| * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type | ||
| * is [a: int, b: long], then we will hit runtime error and say that we can't construct class | ||
| * `Data` with int and long, because we lost the information that `b` should be a string. | ||
| * | ||
| * This method help us "remember" the require data type by adding a `UpCast`. Note that we | ||
| * don't need to cast struct type because there must be `UnresolvedExtractValue` or | ||
| * `GetStructField` wrapping it, and we will need to handle leaf type. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| * | ||
| * TODO: this only works if the real type is compatible with the encoder's schema, we should | ||
| * also handle error cases. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When you say "type compatibility", is it like type promotion? Have we defined such rules for type promotion in Spark? Thanks
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if we want to automatically downcast where we could possibly truncate the values. Unlike an explicit cast, where the user is asking for it, I think this could be confusing. Consider the following: scala> case class Data(value: Int)
scala> Seq(Int.MaxValue.toLong + 1).toDS().as[Data].collect()
res6: Array[Data] = Array(Data(-2147483648))
I think we at least want to warn, and probably just throw an error. If this is really what they want then they can cast explicitly.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @gatorsmile "type compatibility" means if we do type cast, the @marmbrus I also thought about this, maybe we can create a different cast operator and define the encoder related rules there?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. stale comment? |
||
| */ | ||
| def upCastToExpectedType(expr: Expression, expected: DataType): Expression = expected match { | ||
| case _: StructType => expr | ||
| case _ => UpCast(expr, expected) | ||
| } | ||
|
|
||
| tpe match { | ||
| case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath | ||
|
|
@@ -302,7 +330,7 @@ object ScalaReflection extends ScalaReflection { | |
| if (cls.getName startsWith "scala.Tuple") { | ||
| constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) | ||
| } else { | ||
| constructorFor(fieldType, Some(addToPath(fieldName))) | ||
| constructorFor(fieldType, Some(addToPath(fieldName, dataType))) | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,7 +88,8 @@ class Analyzer( | |
| Batch("UDF", Once, | ||
| HandleNullInputsForUDF), | ||
| Batch("Cleanup", fixedPoint, | ||
| CleanupAliases) | ||
| CleanupAliases, | ||
| RemoveUpCast) | ||
| ) | ||
|
|
||
| /** | ||
|
|
@@ -1169,3 +1170,34 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { | |
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. | ||
| */ | ||
| object RemoveUpCast extends Rule[LogicalPlan] { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe |
||
| private def fail(from: DataType, to: DataType) = { | ||
| throw new AnalysisException( | ||
| s"Cannot up cast ${from.simpleString} to ${to.simpleString} as it may truncate") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it might be good to include the awesome path error messages that we use when we fail to resolve here as well. We might also suggest ways to fix it (i.e. "either add an explicit cast to the input data or choose a higher precision type in the target object"). |
||
| } | ||
|
|
||
| private def checkNumericPrecedence(from: DataType, to: DataType): Boolean = { | ||
| val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) | ||
| val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) | ||
| if (toPrecedence > 0 && fromPrecedence > toPrecedence) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| false | ||
| } else { | ||
| true | ||
| } | ||
| } | ||
|
|
||
| def apply(plan: LogicalPlan): LogicalPlan = { | ||
| plan transformAllExpressions { | ||
| case UpCast(child, dataType) => (child.dataType, dataType) match { | ||
| case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => fail(from, to) | ||
| case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => fail(from, to) | ||
| case (from, to) if !checkNumericPrecedence(from, to) => fail(from, to) | ||
| case _ => Cast(child, dataType) | ||
| } | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract | |
| import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} | ||
| import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.ScalaReflection | ||
| import org.apache.spark.sql.types.{StructField, ObjectType, StructType} | ||
|
|
@@ -235,12 +236,13 @@ case class ExpressionEncoder[T]( | |
|
|
||
| val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) | ||
| val analyzedPlan = SimpleAnalyzer.execute(plan) | ||
| val optimizedPlan = SimplifyCasts(analyzedPlan) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if we should just use the full optimizer here. I guess for now it won't do anything, but since it should never change the answer and we might improve it later that might make more sense.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm also thinking about if we should introduce
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably not a On one hand, Scala/Java are always case sensitive so it seems reasonable to preserve that. On the other hand if you loading from something like hive it would be annoying to have to fix all the columns by hand. @rxin, thoughts?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe encoders should be case sensitive all the time to begin with? It is programming language after all, which is case sensitive. If users complain, we can consider adding them in the future?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SGTM |
||
|
|
||
| // In order to construct instances of inner classes (for example those declared in a REPL cell), | ||
| // we need an instance of the outer scope. This rule substitues those outer objects into | ||
| // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` | ||
| // registry. | ||
| copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform { | ||
| copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform { | ||
| case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => | ||
| val outer = outerScopes.get(n.cls.getDeclaringClass.getName) | ||
| if (outer == null) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.encoders | ||
|
|
||
| import org.apache.spark.sql.AnalysisException | ||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.plans.PlanTest | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
| case class StringLongClass(a: String, b: Long) | ||
|
|
||
| case class StringIntClass(a: String, b: Int) | ||
|
|
||
| case class ComplexClass(a: Long, b: StringLongClass) | ||
|
|
||
| class EncoderResolveSuite extends PlanTest { | ||
|
||
| test("real type doesn't match encoder schema but they are compatible: product") { | ||
| val encoder = ExpressionEncoder[StringLongClass] | ||
| val cls = classOf[StringLongClass] | ||
|
|
||
| var attrs = Seq('a.string, 'b.int) | ||
| var fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression | ||
| var expected: Expression = NewInstance( | ||
| cls, | ||
| toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil, | ||
| false, | ||
| ObjectType(cls)) | ||
| compareExpressions(fromRowExpr, expected) | ||
|
|
||
| attrs = Seq('a.int, 'b.long) | ||
| fromRowExpr = encoder.resolve(attrs, null).fromRowExpression | ||
| expected = NewInstance( | ||
| cls, | ||
| toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil, | ||
| false, | ||
| ObjectType(cls)) | ||
| compareExpressions(fromRowExpr, expected) | ||
| } | ||
|
|
||
| test("real type doesn't match encoder schema but they are compatible: nested product") { | ||
| val encoder = ExpressionEncoder[ComplexClass] | ||
| val innerCls = classOf[StringLongClass] | ||
| val cls = classOf[ComplexClass] | ||
|
|
||
| val structType = new StructType().add("a", IntegerType).add("b", LongType) | ||
| val attrs = Seq('a.int, 'b.struct(structType)) | ||
| val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression | ||
| val expected: Expression = NewInstance( | ||
| cls, | ||
| Seq( | ||
| 'a.int.cast(LongType), | ||
| If( | ||
| 'b.struct(structType).isNull, | ||
| Literal.create(null, ObjectType(innerCls)), | ||
| NewInstance( | ||
| innerCls, | ||
| Seq( | ||
| toExternalString(GetStructField( | ||
| 'b.struct(structType), | ||
| structType(0), | ||
| 0).cast(StringType)), | ||
| GetStructField( | ||
| 'b.struct(structType), | ||
| structType(1), | ||
| 1)), | ||
| false, | ||
| ObjectType(innerCls)) | ||
| )), | ||
| false, | ||
| ObjectType(cls)) | ||
| compareExpressions(fromRowExpr, expected) | ||
| } | ||
|
|
||
| test("real type doesn't match encoder schema but they are compatible: tupled encoder") { | ||
| val encoder = ExpressionEncoder.tuple( | ||
| ExpressionEncoder[StringLongClass], | ||
| ExpressionEncoder[Long]) | ||
| val cls = classOf[StringLongClass] | ||
|
|
||
| val structType = new StructType().add("a", StringType).add("b", ByteType) | ||
| val attrs = Seq('a.struct(structType), 'b.int) | ||
| val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression | ||
| val expected: Expression = NewInstance( | ||
| classOf[Tuple2[_, _]], | ||
| Seq( | ||
| NewInstance( | ||
| cls, | ||
| Seq( | ||
| toExternalString(GetStructField( | ||
| 'a.struct(structType), | ||
| structType(0), | ||
| 0)), | ||
| GetStructField( | ||
| 'a.struct(structType), | ||
| structType(1), | ||
| 1).cast(LongType)), | ||
| false, | ||
| ObjectType(cls)), | ||
| 'b.int.cast(LongType)), | ||
| false, | ||
| ObjectType(classOf[Tuple2[_, _]])) | ||
| compareExpressions(fromRowExpr, expected) | ||
| } | ||
|
|
||
| private def toExternalString(e: Expression): Expression = { | ||
| Invoke(e, "toString", ObjectType(classOf[String]), Nil) | ||
| } | ||
|
|
||
| test("throw exception if real type is not compatible with encoder schema") { | ||
| intercept[AnalysisException] { | ||
| ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) | ||
| } | ||
|
|
||
| intercept[AnalysisException] { | ||
| val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) | ||
| ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null) | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
required