@@ -18,9 +18,8 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -124,17 +123,46 @@ object ScalaReflection extends ScalaReflection {
path: Option[Expression]): Expression = ScalaReflectionLock.synchronized {

/** Returns the current path with a sub-field extracted. */
def addToPath(part: String): Expression = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))
def addToPath(part: String, dataType: DataType): Expression = {
val newPath = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))
upCastToExpectedType(newPath, dataType)
}

/** Returns the current path with a field at ordinal extracted. */
def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
.map(p => GetInternalRowField(p, ordinal, dataType))
.getOrElse(BoundReference(ordinal, dataType, false))
def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = {
val newPath = path
.map(p => GetStructField(p, new StructField("", dataType), ordinal))
.getOrElse(BoundReference(ordinal, dataType, false))
upCastToExpectedType(newPath, dataType)
}

/** Returns the current path or `BoundReference`. */
def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true))
def getPath: Expression = {
val dataType = schemaFor(tpe).dataType
path.getOrElse(upCastToExpectedType(BoundReference(0, dataType, true), dataType))
}

/**
* When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff
* and lose the required data type, which may lead to runtime errors if the real type doesn't
* match the encoder's schema.
* For example, if we build an encoder for `case class Data(a: Int, b: String)` and the real type
* is [a: int, b: long], then we will hit a runtime error saying that we can't construct class
* `Data` from an int and a long, because we lost the information that `b` should be a string.
*
* This method helps us "remember" the required data type by adding an `UpCast`. Note that we
[Contributor] nit: required
* don't need to cast struct types because there must be an `UnresolvedExtractValue` or
* `GetStructField` wrapping them, thus we only need to handle leaf types.
[Contributor] thus we only need to handle leaf types.?
*
* TODO: this only works if the real type is compatible with the encoder's schema; we should
* also handle error cases.
[Member] When you say "type compatibility", is it like type promotion? Have we defined such rules for type promotion in Spark? Thanks

[Contributor] I'm not sure if we want to automatically downcast where we could possibly truncate the values. Unlike an explicit cast, where the user is asking for it, I think this could be confusing. Consider the following:

scala> case class Data(value: Int)
scala> Seq(Int.MaxValue.toLong + 1).toDS().as[Data].collect()
res6: Array[Data] = Array(Data(-2147483648))

I think we at least want to warn, and probably just throw an error. If this is really what they want then they can cast explicitly.

[Contributor Author] @gatorsmile "type compatibility" means that if we do a type cast, the Cast operator can be resolved; see https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala#L35-L85

@marmbrus I also thought about this, maybe we can create a different cast operator and define the encoder-related rules there?

[Contributor] stale comment?
*/
def upCastToExpectedType(expr: Expression, expected: DataType): Expression = expected match {
case _: StructType => expr
case _ => UpCast(expr, expected)
}
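As a rough editorial sketch (not part of the patch), this is the shape of expression that `addToPath` now builds for field `b` of `case class Data(a: Int, b: String)` when there is no parent path; `UnresolvedAttribute` and `UpCast` are the Catalyst classes referenced above, and the usual Catalyst imports are assumed.

// Before this patch: the extracted field keeps whatever type the input column happens to have.
val withoutExpectedType = UnresolvedAttribute("b")
// After this patch: the encoder's expected type is remembered via UpCast, so the analyzer
// can later replace it with a checked Cast (or fail) once the real input type is known.
val withExpectedType = UpCast(UnresolvedAttribute("b"), StringType)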

tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
@@ -302,7 +330,7 @@ object ScalaReflection extends ScalaReflection {
if (cls.getName startsWith "scala.Tuple") {
constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
} else {
constructorFor(fieldType, Some(addToPath(fieldName)))
constructorFor(fieldType, Some(addToPath(fieldName, dataType)))
}
}

@@ -88,7 +88,8 @@ class Analyzer(
Batch("UDF", Once,
HandleNullInputsForUDF),
Batch("Cleanup", fixedPoint,
CleanupAliases)
CleanupAliases,
RemoveUpCast)
)

/**
@@ -1169,3 +1170,34 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
}
}
}

/**
* Replace `UpCast` expressions with `Cast`, and throw an exception if the cast may truncate.
*/
object RemoveUpCast extends Rule[LogicalPlan] {
[Contributor] Maybe ResolveUpCasts
private def fail(from: DataType, to: DataType) = {
throw new AnalysisException(
s"Cannot up cast ${from.simpleString} to ${to.simpleString} as it may truncate")
[Contributor] I think it might be good to include the awesome path error messages that we use when we fail to resolve here as well. We might also suggest ways to fix it (i.e. "either add an explicit cast to the input data or choose a higher precision type in the target object").
}

private def checkNumericPrecedence(from: DataType, to: DataType): Boolean = {
val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
if (toPrecedence > 0 && fromPrecedence > toPrecedence) {
[Contributor] !(toPrecedence > 0 && fromPrecedence > toPrecedence)

false
} else {
true
}
}

def apply(plan: LogicalPlan): LogicalPlan = {
plan transformAllExpressions {
case UpCast(child, dataType) => (child.dataType, dataType) match {
case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => fail(from, to)
case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => fail(from, to)
case (from, to) if !checkNumericPrecedence(from, to) => fail(from, to)
case _ => Cast(child, dataType)
}
}
}
}
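To make the truncation rule above concrete, here is a small standalone Scala model of the numeric-precedence check (an editorial sketch that mirrors the logic as written; it does not call into Catalyst, and plain strings stand in for Spark's DataTypes):

object UpCastCheckModel {
  // Widening order for numeric types, mirroring HiveTypeCoercion.numericPrecedence.
  private val precedence = IndexedSeq("byte", "short", "int", "long", "float", "double")

  // A cast is rejected only when both sides are numeric and the source sits strictly
  // later (wider) in the precedence list than the target, exactly as in the rule above.
  def safeNumericCast(from: String, to: String): Boolean = {
    val fromPrecedence = precedence.indexOf(from)
    val toPrecedence = precedence.indexOf(to)
    !(toPrecedence > 0 && fromPrecedence > toPrecedence)
  }

  def main(args: Array[String]): Unit = {
    assert(safeNumericCast("int", "long"))   // widening: UpCast becomes a plain Cast
    assert(!safeNumericCast("long", "int"))  // narrowing: RemoveUpCast raises an AnalysisException
  }
}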
@@ -53,7 +53,7 @@ object HiveTypeCoercion {

// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
// The conversions for integral and floating point types have a linear widening hierarchy:
private val numericPrecedence =
private[sql] val numericPrecedence =
IndexedSeq(
ByteType,
ShortType,
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
@@ -235,12 +236,13 @@ case class ExpressionEncoder[T](

val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
val optimizedPlan = SimplifyCasts(analyzedPlan)
[Contributor] I wonder if we should just use the full optimizer here. I guess for now it won't do anything, but since it should never change the answer and we might improve it later, that might make more sense.

[Contributor Author] I'm also thinking about whether we should introduce sqlContext here, and use its analyzer and optimizer. For now our encoder resolution is case sensitive regardless of the CASE_SENSITIVE config.

[Contributor] Probably not a SQLContext, but a CatalystConfig would be reasonable. I wonder if it should be a different setting than SQL case sensitivity resolution?

On one hand, Scala/Java are always case sensitive so it seems reasonable to preserve that. On the other hand, if you're loading from something like Hive it would be annoying to have to fix all the columns by hand.

@rxin, thoughts?

[Contributor] Maybe encoders should be case sensitive all the time to begin with? It is a programming language after all, which is case sensitive. If users complain, we can consider adding them in the future?

[Contributor] SGTM

// In order to construct instances of inner classes (for example those declared in a REPL cell),
// we need an instance of the outer scope. This rule substitutes those outer objects into
// expressions that are missing them by looking up the name in the SQLContext's `outerScopes`
// registry.
copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform {
copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform {
case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
if (outer == null) {
@@ -104,8 +104,7 @@ object Cast {
}

/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType)
extends UnaryExpression with CodegenFallback {
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {

override def toString: String = s"cast($child as ${dataType.simpleString})"

Expand Down Expand Up @@ -915,3 +914,9 @@ case class Cast(child: Expression, dataType: DataType)
"""
}
}

/**
* Cast the child expression to the target data type, but throw an error if the cast might
* truncate, e.g. long -> int, timestamp -> date.
*/
case class UpCast(child: Expression, dataType: DataType) extends UnaryExpression with Unevaluable
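For illustration (an editorial sketch, not part of this file; the usual Catalyst expression and type imports are assumed), the intended difference between the two operators on a narrowing long -> int conversion:

// Cast evaluates and silently wraps around on overflow, e.g. the example from the review
// discussion: casting Int.MaxValue.toLong + 1 to int yields -2147483648.
val silent = Cast(Literal(Int.MaxValue.toLong + 1), IntegerType)
// UpCast is Unevaluable: it only survives until analysis, where RemoveUpCast either rewrites
// it to a plain Cast (for safe widenings) or fails with an AnalysisException (here: long -> int).
val checked = UpCast(Literal(Int.MaxValue.toLong + 1), IntegerType)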
@@ -126,7 +126,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
case class CreateNamedStruct(children: Seq[Expression]) extends Expression {

/**
* Returns Aliased [[Expressions]] that could be used to construct a flattened version of this
* Returns Aliased [[Expression]]s that could be used to construct a flattened version of this
* StructType.
*/
def flatten: Seq[NamedExpression] = valExprs.zip(names).map {
@@ -548,19 +548,22 @@ case class MapGroups[K, T, U](

/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder](
func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
def apply[Key, Left, Right, Result : Encoder](
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
keyEnc: ExpressionEncoder[Key],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan): CoGroup[K, Left, Right, R] = {
right: LogicalPlan): CoGroup[Key, Left, Right, Result] = {
CoGroup(
func,
encoderFor[K],
encoderFor[Left],
encoderFor[Right],
encoderFor[R],
encoderFor[R].schema.toAttributes,
keyEnc,
leftEnc,
rightEnc,
encoderFor[Result],
encoderFor[Result].schema.toAttributes,
leftGroup,
rightGroup,
left,
@@ -574,10 +577,10 @@
*/
case class CoGroup[K, Left, Right, R](
func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
kEncoder: ExpressionEncoder[K],
keyEnc: ExpressionEncoder[K],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
rEncoder: ExpressionEncoder[R],
resultEnc: ExpressionEncoder[R],
output: Seq[Attribute],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
@@ -90,6 +90,18 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
case _ => false
}

/**
* Returns whether this DecimalType is tighter than `other`. If yes, it means `this`
* can be cast into `other` safely without losing any precision or range.
*/
private[sql] def isTighterThan(other: DataType): Boolean = other match {
case dt: DecimalType =>
(precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale
case dt: IntegralType =>
isTighterThan(DecimalType.forType(dt))
case _ => false
}
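As a small standalone model of this rule (editorial sketch; `tighterThan` is a hypothetical stand-in, not a Spark API): a Decimal(p1, s1) fits into Decimal(p2, s2) only if it loses neither integer digits nor scale.

// Decimal-to-decimal case only; in the real code, integral targets are first converted
// to an equivalent DecimalType via DecimalType.forType before applying the same check.
def tighterThan(p1: Int, s1: Int, p2: Int, s2: Int): Boolean =
  (p1 - s1) <= (p2 - s2) && s1 <= s2

assert(tighterThan(5, 2, 10, 2))   // Decimal(5,2) -> Decimal(10,2): safe
assert(!tighterThan(10, 2, 5, 2))  // would drop integer digits: not tighter
assert(!tighterThan(5, 2, 6, 1))   // would drop a digit of scale: not tighter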

/**
* The default size of a value of the DecimalType is 4096 bytes.
*/
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.encoders

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types._

case class StringLongClass(a: String, b: Long)

case class StringIntClass(a: String, b: Int)

case class ComplexClass(a: Long, b: StringLongClass)

class EncoderResolveSuite extends PlanTest {
[Contributor] Nit: this should probably be EncoderResolutionSuite

[Contributor]
It would also be very helpful to add a bunch of tests like castFail[Long, Int], castSuccess[Int, Long]. This will both make sure we don't change the rules in the future and help me audit to make sure the current auto conversions make sense.
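A possible shape for such helpers (editorial sketch with hypothetical names, reusing the case classes and the resolve(attrs, null) call already used in this suite):

// castSuccess-style check: an int column may be widened into the Long field of StringLongClass.
private def assertWidensIntToLong(): Unit = {
  ExpressionEncoder[StringLongClass].resolve(Seq('a.string, 'b.int), null)
}

// castFail-style check: a long column must not be silently narrowed into the Int field of StringIntClass.
private def assertRejectsLongToInt(): Unit = {
  intercept[AnalysisException] {
    ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
  }
}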

test("real type doesn't match encoder schema but they are compatible: product") {
val encoder = ExpressionEncoder[StringLongClass]
val cls = classOf[StringLongClass]

var attrs = Seq('a.string, 'b.int)
var fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
var expected: Expression = NewInstance(
cls,
toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil,
false,
ObjectType(cls))
compareExpressions(fromRowExpr, expected)

attrs = Seq('a.int, 'b.long)
fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
expected = NewInstance(
cls,
toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil,
false,
ObjectType(cls))
compareExpressions(fromRowExpr, expected)
}

test("real type doesn't match encoder schema but they are compatible: nested product") {
val encoder = ExpressionEncoder[ComplexClass]
val innerCls = classOf[StringLongClass]
val cls = classOf[ComplexClass]

val structType = new StructType().add("a", IntegerType).add("b", LongType)
val attrs = Seq('a.int, 'b.struct(structType))
val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
val expected: Expression = NewInstance(
cls,
Seq(
'a.int.cast(LongType),
If(
'b.struct(structType).isNull,
Literal.create(null, ObjectType(innerCls)),
NewInstance(
innerCls,
Seq(
toExternalString(GetStructField(
'b.struct(structType),
structType(0),
0).cast(StringType)),
GetStructField(
'b.struct(structType),
structType(1),
1)),
false,
ObjectType(innerCls))
)),
false,
ObjectType(cls))
compareExpressions(fromRowExpr, expected)
}

test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
val encoder = ExpressionEncoder.tuple(
ExpressionEncoder[StringLongClass],
ExpressionEncoder[Long])
val cls = classOf[StringLongClass]

val structType = new StructType().add("a", StringType).add("b", ByteType)
val attrs = Seq('a.struct(structType), 'b.int)
val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
val expected: Expression = NewInstance(
classOf[Tuple2[_, _]],
Seq(
NewInstance(
cls,
Seq(
toExternalString(GetStructField(
'a.struct(structType),
structType(0),
0)),
GetStructField(
'a.struct(structType),
structType(1),
1).cast(LongType)),
false,
ObjectType(cls)),
'b.int.cast(LongType)),
false,
ObjectType(classOf[Tuple2[_, _]]))
compareExpressions(fromRowExpr, expected)
}

private def toExternalString(e: Expression): Expression = {
Invoke(e, "toString", ObjectType(classOf[String]), Nil)
}

test("throw exception if real type is not compatible with encoder schema") {
intercept[AnalysisException] {
ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
}

intercept[AnalysisException] {
val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT)
ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null)
}
}
}
@@ -304,11 +304,13 @@ class GroupedDataset[K, V] private[sql](
def cogroup[U, R : Encoder](
other: GroupedDataset[K, U])(
f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
implicit def uEnc: Encoder[U] = other.unresolvedTEncoder
new Dataset[R](
sqlContext,
CoGroup(
f,
resolvedKEncoder,
this.resolvedTEncoder,
other.resolvedTEncoder,
this.groupingAttributes,
other.groupingAttributes,
this.logicalPlan,
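For context, a hedged usage sketch of the public API this change plumbs through (assumes a SQLContext named sqlContext, its implicits in scope, and the Spark 1.6-era Dataset API):

import sqlContext.implicits._

val left = Seq(1 -> "a", 1 -> "b", 2 -> "c").toDS().groupBy(_._1)
val right = Seq(1 -> 10L, 3 -> 30L).toDS().groupBy(_._1)

// For each key, emit one summary string; the key, left and right encoders are now
// passed to the CoGroup node explicitly instead of being re-derived from context bounds.
val summary = left.cogroup(right) { (key, ls, rs) =>
  Iterator(s"$key: ${ls.size} left rows, ${rs.size} right rows")
}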