Skip to content

Commit 6ce1b16

Browse files
committed
[SPARK-35278][SQL] Invoke should find the method with correct number of parameters
### What changes were proposed in this pull request? This patch fixes `Invoke` expression when the target object has more than one method with the given method name. ### Why are the changes needed? `Invoke` will find out the method on the target object with given method name. If there are more than one method with the name, currently it is undeterministic which method will be used. We should add the condition of parameter number when finding the method. ### Does this PR introduce _any_ user-facing change? Yes, fixed a bug when using `Invoke` on a object where more than one method with the given method name. ### How was this patch tested? Unit test. Closes #32404 from viirya/verify-invoke-param-len. Authored-by: Liang-Chi Hsieh <[email protected]> Signed-off-by: Liang-Chi Hsieh <[email protected]>
1 parent 72e238a commit 6ce1b16

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -326,11 +326,30 @@ case class Invoke(
326326

327327
@transient lazy val method = targetObject.dataType match {
328328
case ObjectType(cls) =>
329-
val m = cls.getMethods.find(_.getName == encodedFunctionName)
330-
if (m.isEmpty) {
331-
sys.error(s"Couldn't find $encodedFunctionName on $cls")
332-
} else {
333-
m
329+
// Looking with function name + argument classes first.
330+
try {
331+
Some(cls.getMethod(encodedFunctionName, argClasses: _*))
332+
} catch {
333+
case _: NoSuchMethodException =>
334+
// For some cases, e.g. arg class is Object, `getMethod` cannot find the method.
335+
// We look at function name + argument length
336+
val m = cls.getMethods.filter { m =>
337+
m.getName == encodedFunctionName && m.getParameterCount == arguments.length
338+
}
339+
if (m.isEmpty) {
340+
sys.error(s"Couldn't find $encodedFunctionName on $cls")
341+
} else if (m.length > 1) {
342+
// More than one matched method signature. Exclude synthetic one, e.g. generic one.
343+
val realMethods = m.filter(!_.isSynthetic)
344+
if (realMethods.length > 1) {
345+
// Ambiguous case, we don't know which method to choose, just fail it.
346+
sys.error(s"Found ${realMethods.length} $encodedFunctionName on $cls")
347+
} else {
348+
Some(realMethods.head)
349+
}
350+
} else {
351+
Some(m.head)
352+
}
334353
}
335354
case _ => None
336355
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,29 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
618618
checkExceptionInExpression[ArithmeticException](
619619
StaticInvoke(mathCls, IntegerType, "addExact", Seq(Literal(Int.MaxValue), Literal(1))), "")
620620
}
621+
622+
test("SPARK-35278: invoke should find method with correct number of parameters") {
623+
val strClsType = ObjectType(classOf[String])
624+
checkExceptionInExpression[StringIndexOutOfBoundsException](
625+
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(3))), "")
626+
627+
checkObjectExprEvaluation(
628+
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0))), "a")
629+
630+
checkExceptionInExpression[StringIndexOutOfBoundsException](
631+
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0), Literal(3))), "")
632+
633+
checkObjectExprEvaluation(
634+
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0), Literal(1))), "a")
635+
}
636+
637+
test("SPARK-35278: invoke should correctly invoke override method") {
638+
val clsType = ObjectType(classOf[ConcreteClass])
639+
val obj = new ConcreteClass
640+
641+
checkObjectExprEvaluation(
642+
Invoke(Literal(obj, clsType), "testFunc", IntegerType, Seq(Literal(1))), 0)
643+
}
621644
}
622645

623646
class TestBean extends Serializable {
@@ -628,3 +651,11 @@ class TestBean extends Serializable {
628651
def setNonPrimitive(i: AnyRef): Unit =
629652
assert(i != null, "this setter should not be called with null.")
630653
}
654+
655+
abstract class BaseClass[T] {
656+
def testFunc(param: T): T
657+
}
658+
659+
class ConcreteClass extends BaseClass[Int] with Serializable {
660+
override def testFunc(param: Int): Int = param - 1
661+
}

0 commit comments

Comments
 (0)