Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

package org.apache.spark.sql.catalyst

import java.lang.reflect.Constructor

import org.apache.commons.lang3.reflect.ConstructorUtils

import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
Expand Down Expand Up @@ -781,6 +785,15 @@ object ScalaReflection extends ScalaReflection {
}
}

/**
* Finds an accessible constructor with compatible parameters. This is a more flexible search
* than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible
* matching constructor is returned. Otherwise, it returns `None`.
*/
def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = {
Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*))
}

/**
* Whether the fields of the given type is defined entirely by its constructor parameters.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,32 @@ case class NewInstance(
childrenResolved && !needOuterPointer
}

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@transient private lazy val constructor: (Seq[AnyRef]) => Any = {
val paramTypes = ScalaReflection.expressionJavaClasses(arguments)
val getConstructor = (paramClazz: Seq[Class[_]]) => {
ScalaReflection.findConstructor(cls, paramClazz).getOrElse {
sys.error(s"Couldn't find a valid constructor on $cls")
}
}
outerPointer.map { p =>
val outerObj = p()
val d = outerObj.getClass +: paramTypes
val c = getConstructor(outerObj.getClass +: paramTypes)
(args: Seq[AnyRef]) => {
c.newInstance(outerObj +: args: _*)
}
}.getOrElse {
val c = getConstructor(paramTypes)
(args: Seq[AnyRef]) => {
c.newInstance(args: _*)
}
}
}

override def eval(input: InternalRow): Any = {
val argValues = arguments.map(_.eval(input))
constructor(argValues.map(_.asInstanceOf[AnyRef]))
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = CodeGenerator.javaType(dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ class InvokeTargetSubClass extends InvokeTargetClass {
override def binOp(e1: Int, e2: Double): Double = e1 - e2
}

// Tests for NewInstance
class Outer extends Serializable {
class Inner(val value: Int) {
override def hashCode(): Int = super.hashCode()
override def equals(other: Any): Boolean = {
if (other.isInstanceOf[Inner]) {
value == other.asInstanceOf[Inner].value
} else {
false
}
}
}
}

class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("SPARK-16622: The returned value of the called method in Invoke can be null") {
Expand Down Expand Up @@ -383,6 +397,27 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("SPARK-23584 NewInstance should support interpreted execution") {
// Normal case test
val newInst1 = NewInstance(
cls = classOf[GenericArrayData],
arguments = Literal.fromObject(List(1, 2, 3)) :: Nil,
propagateNull = false,
dataType = ArrayType(IntegerType),
outerPointer = None)
checkObjectExprEvaluation(newInst1, new GenericArrayData(List(1, 2, 3)))

// Inner class case test
val outerObj = new Outer()
val newInst2 = NewInstance(
cls = classOf[outerObj.Inner],
arguments = Literal(1) :: Nil,
propagateNull = false,
dataType = ObjectType(classOf[outerObj.Inner]),
outerPointer = Some(() => outerObj))
checkObjectExprEvaluation(newInst2, new outerObj.Inner(1))
}

test("LambdaVariable should support interpreted execution") {
def genSchema(dt: DataType): Seq[StructType] = {
Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil),
Expand Down Expand Up @@ -421,6 +456,7 @@ class TestBean extends Serializable {
private var x: Int = 0

def setX(i: Int): Unit = x = i

def setNonPrimitive(i: AnyRef): Unit =
assert(i != null, "this setter should not be called with null.")
}