Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,19 @@ object ScalaReflection extends ScalaReflection {
}
}

/**
 * Returns the boxed JVM class that holds a value of the given Catalyst data type.
 *
 * Types with a dedicated case map to the class of the value itself (for example `Decimal`
 * for decimals and `InternalRow` for structs). A user-defined type resolves through its
 * underlying `sqlType`, and an `ObjectType` yields the class it wraps. Every other type is
 * looked up in `typeBoxedJavaMapping`, falling back to `java.lang.Object` when absent.
 *
 * @param dt the Catalyst data type to resolve
 * @return the class whose instances represent values of `dt`
 */
def javaBoxedType(dt: DataType): Class[_] = dt match {
  case _: DecimalType => classOf[Decimal]
  case BinaryType => classOf[Array[Byte]]
  case StringType => classOf[UTF8String]
  case CalendarIntervalType => classOf[CalendarInterval]
  case _: StructType => classOf[InternalRow]
  // ArrayType and MapType are DataType descriptors, so no value is ever an instance of them.
  // Return the value containers instead, fully qualified so no extra import is required.
  case _: ArrayType => classOf[org.apache.spark.sql.catalyst.util.ArrayData]
  case _: MapType => classOf[org.apache.spark.sql.catalyst.util.MapData]
  case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType)
  case ObjectType(cls) => cls
  case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object])
}

def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = {
if (arguments != Nil) {
arguments.map(e => dataTypeJavaClass(e.dataType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.Utils

/**
Expand Down Expand Up @@ -1670,13 +1671,36 @@ case class ValidateExternalType(child: Expression, expected: DataType)

// The result can be null exactly when the child expression can be null.
override def nullable: Boolean = child.nullable

// NOTE(review): `dataType` is defined here as a `def` and again below as a `val`, and `eval`
// also appears twice in this view. This looks like the pre-change and post-change lines of a
// diff shown together -- confirm that only one definition of each survives in the real file.
override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected)

// This variant of `eval` never evaluates anything: it always throws
// UnsupportedOperationException.
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
// Result type, computed once from `expected` via RowEncoder.externalDataTypeForInput.
override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected)

// Suffix of the type-mismatch error message; the offending value's class name is prepended
// where the exception is thrown.
private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}"

/**
 * Predicate telling whether an external value is acceptable for the `expected` type.
 *
 * It is built lazily, once, by dispatching on `expected`:
 *  - decimal types accept `java.math.BigDecimal`, `scala.math.BigDecimal` and `Decimal`;
 *  - array types accept any JVM array or any `Seq`;
 *  - every other type accepts instances of the class that `ScalaReflection.javaBoxedType`
 *    returns for `dataType`.
 */
private lazy val checkType: (Any) => Boolean = expected match {
  case _: DecimalType =>
    (candidate: Any) => candidate match {
      case _: java.math.BigDecimal => true
      case _: scala.math.BigDecimal => true
      case _: Decimal => true
      case _ => false
    }
  case _: ArrayType =>
    (candidate: Any) => {
      // The runtime class is inspected first, exactly as before.
      val runtimeClass = candidate.getClass
      runtimeClass.isArray || classOf[Seq[_]].isInstance(candidate)
    }
  case _ =>
    // Resolve the accepted class once, outside the returned predicate.
    val acceptedClass = ScalaReflection.javaBoxedType(dataType)
    (candidate: Any) => acceptedClass.isInstance(candidate)
}

/**
 * Interpreted evaluation: evaluates the child and returns its value unchanged when the value
 * is acceptable for the `expected` type.
 *
 * A null child result is returned as null without validation. This expression is nullable
 * whenever its child is, and a null carries no runtime class to check.
 *
 * @param input the row the child expression is evaluated against
 * @return the child's value, or null when the child evaluates to null
 * @throws RuntimeException if the value is non-null and fails the type check
 */
override def eval(input: InternalRow): Any = {
  val result = child.eval(input)
  // Short-circuit on null: both the type check and the error message below would otherwise
  // dereference it and fail with a NullPointerException.
  if (result == null || checkType(result)) {
    result
  } else {
    throw new RuntimeException(s"${result.getClass.getName}$errMsg")
  }
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the type doesn't match.
Expand All @@ -1689,7 +1713,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
.map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
case _: ArrayType =>
s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to change the codegen implementation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on @rednaxelafx's comment, I changed this codegen as well. #20757 (comment)

case _ =>
s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class InvokeTargetClass extends Serializable {
def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0
Expand Down Expand Up @@ -296,7 +296,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0")
Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) =>
checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input)))
checkObjectExprEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input)))
}

// If an input row or a field are null, a runtime exception will be thrown
Expand Down Expand Up @@ -472,6 +472,35 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val deserializer = toMapExpr.copy(inputData = Literal.create(data))
checkObjectExprEvaluation(deserializer, expected = data)
}

test("SPARK-23595 ValidateExternalType should support interpreted execution") {
  val rowRef = BoundReference(0, ObjectType(classOf[Row]), nullable = true)

  // Builds a fresh expression validating field "c0" of the input external Row against `dt`.
  def validateFirstField(dt: DataType): ValidateExternalType =
    ValidateExternalType(GetExternalRowField(rowRef, index = 0, fieldName = "c0"), dt)

  // Each pair is an external value together with a Catalyst type that should accept it.
  val acceptedInputs = Seq(
    (true, BooleanType),
    (2.toByte, ByteType),
    (5.toShort, ShortType),
    (23, IntegerType),
    (61L, LongType),
    (1.0f, FloatType),
    (10.0, DoubleType),
    ("abcd".getBytes, BinaryType),
    ("abcd", StringType),
    (BigDecimal.valueOf(10), DecimalType.IntDecimal),
    (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType),
    (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal),
    (Array(3, 2, 1), ArrayType(IntegerType))
  )

  // A valid value must come back unchanged.
  acceptedInputs.foreach { case (input, dt) =>
    val wrappedRow = InternalRow.fromSeq(Seq(Row(input)))
    checkObjectExprEvaluation(validateFirstField(dt), input, wrappedRow)
  }

  // An Integer supplied for a double column must be rejected with a descriptive message.
  checkExceptionInExpression[RuntimeException](
    validateFirstField(DoubleType),
    InternalRow.fromSeq(Seq(Row(1))),
    "java.lang.Integer is not a valid external type for schema of double")
}
}

class TestBean extends Serializable {
Expand Down