
Commit cc88d7f

Fangshi Li authored and cloud-fan committed
[SPARK-24216][SQL] Spark TypedAggregateExpression uses getSimpleName that is not safe in scala
## What changes were proposed in this pull request?

When a user creates an Aggregator object in Scala and passes it to a Spark Dataset's agg() method, Spark initializes TypedAggregateExpression with its nodeName field set to aggregator.getClass.getSimpleName. However, getSimpleName is not safe in a Scala environment, depending on how the user created the aggregator object. For example, if the aggregator class's fully qualified name is "com.my.company.MyUtils$myAgg$2$", getSimpleName throws java.lang.InternalError: "Malformed class name". This has been reported in scalatest (scalatest/scalatest#1044) and discussed in several Scala upstream tickets such as SI-8110 and SI-5425.

To fix this issue, we follow the solution in scalatest/scalatest#1044 and add a safer version of getSimpleName as a util method; TypedAggregateExpression invokes this util method rather than getClass.getSimpleName.

## How was this patch tested?

Added unit test.

Author: Fangshi Li <[email protected]>

Closes #21276 from fangshil/SPARK-24216.
1 parent f0ef1b3 commit cc88d7f
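As context for the fix, here is a minimal standalone sketch of the failure mode (the `Outer`/`Nested`/`Inner` names are illustrative and mirror the nesting used in the new UtilsSuite test): on affected JDK 8 releases, `Class.getSimpleName` can throw `java.lang.InternalError("Malformed class name")` for a class nested inside an object that is itself nested inside a class.

```scala
class Outer {
  object Nested {
    class Inner
  }

  def simpleNameOrFallback(): String = {
    val cls = classOf[Nested.Inner]  // binary name: Outer$Nested$Inner
    try {
      cls.getSimpleName              // may throw InternalError("Malformed class name") on JDK 8
    } catch {
      case _: InternalError =>
        // Simplified fallback in the spirit of the Utils.getSimpleName added by
        // this patch: keep the class name, minus the package prefix.
        cls.getName.split("\\.").last
    }
  }
}

object Repro {
  def main(args: Array[String]): Unit =
    println(new Outer().simpleNameOrFallback())  // prints "Outer$Nested$Inner"
}
```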

File tree

6 files changed, +89 -6 lines changed


core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala

Lines changed: 4 additions & 2 deletions
@@ -200,10 +200,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
   }

   override def toString: String = {
+    // getClass.getSimpleName can cause Malformed class name error,
+    // call safer `Utils.getSimpleName` instead
     if (metadata == null) {
-      "Un-registered Accumulator: " + getClass.getSimpleName
+      "Un-registered Accumulator: " + Utils.getSimpleName(getClass)
     } else {
-      getClass.getSimpleName + s"(id: $id, name: $name, value: $value)"
+      Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)"
     }
   }
 }

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 58 additions & 1 deletion
@@ -19,6 +19,7 @@ package org.apache.spark.util

 import java.io._
 import java.lang.{Byte => JByte}
+import java.lang.InternalError
 import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
 import java.lang.reflect.InvocationTargetException
 import java.math.{MathContext, RoundingMode}
@@ -1820,7 +1821,7 @@ private[spark] object Utils extends Logging {

   /** Return the class name of the given object, removing all dollar signs */
   def getFormattedClassName(obj: AnyRef): String = {
-    obj.getClass.getSimpleName.replace("$", "")
+    getSimpleName(obj.getClass).replace("$", "")
   }

   /**
@@ -2715,6 +2716,62 @@ private[spark] object Utils extends Logging {
     HashCodes.fromBytes(secretBytes).toString()
   }

+  /**
+   * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala.
+   * This method mimicks scalatest's getSimpleNameOfAnObjectsClass.
+   */
+  def getSimpleName(cls: Class[_]): String = {
+    try {
+      return cls.getSimpleName
+    } catch {
+      case err: InternalError => return stripDollars(stripPackages(cls.getName))
+    }
+  }
+
+  /**
+   * Remove the packages from full qualified class name
+   */
+  private def stripPackages(fullyQualifiedName: String): String = {
+    fullyQualifiedName.split("\\.").takeRight(1)(0)
+  }
+
+  /**
+   * Remove trailing dollar signs from qualified class name,
+   * and return the trailing part after the last dollar sign in the middle
+   */
+  private def stripDollars(s: String): String = {
+    val lastDollarIndex = s.lastIndexOf('$')
+    if (lastDollarIndex < s.length - 1) {
+      // The last char is not a dollar sign
+      if (lastDollarIndex == -1 || !s.contains("$iw")) {
+        // The name does not have dollar sign or is not an intepreter
+        // generated class, so we should return the full string
+        s
+      } else {
+        // The class name is intepreter generated,
+        // return the part after the last dollar sign
+        // This is the same behavior as getClass.getSimpleName
+        s.substring(lastDollarIndex + 1)
+      }
+    }
+    else {
+      // The last char is a dollar sign
+      // Find last non-dollar char
+      val lastNonDollarChar = s.reverse.find(_ != '$')
+      lastNonDollarChar match {
+        case None => s
+        case Some(c) =>
+          val lastNonDollarIndex = s.lastIndexOf(c)
+          if (lastNonDollarIndex == -1) {
+            s
+          } else {
+            // Strip the trailing dollar signs
+            // Invoke stripDollars again to get the simple name
+            stripDollars(s.substring(0, lastNonDollarIndex + 1))
+          }
+      }
+    }
+  }
 }

 private[util] object CallerContext extends Logging {
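For a sense of what the fallback path produces, here is a worked trace and a small usage sketch. The package location, object name, and the spark-shell style class name are assumptions for illustration; `Utils` is `private[spark]`, so a caller must live under `org.apache.spark`.

```scala
package org.apache.spark.demo  // hypothetical subpackage so private[spark] Utils is visible

import org.apache.spark.util.Utils

object SafeSimpleNameDemo {
  // Worked trace of the fallback stripDollars(stripPackages(cls.getName)):
  //
  //   "com.my.company.MyUtils$myAgg$2$"       (name from the PR description)
  //     stripPackages -> "MyUtils$myAgg$2$"
  //     stripDollars  -> trailing '$' dropped, no "$iw" marker -> "MyUtils$myAgg$2"
  //
  //   "$line3.$read$$iw$$iw$myAgg"            (assumed spark-shell style name)
  //     stripPackages -> "$read$$iw$$iw$myAgg"
  //     stripDollars  -> "$iw" marks an interpreter wrapper -> "myAgg"
  def main(args: Array[String]): Unit = {
    // Never throws: falls back to the stripped class name when
    // java.lang.Class.getSimpleName raises InternalError.
    println(Utils.getSimpleName(getClass))  // prints "SafeSimpleNameDemo$"
  }
}
```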

core/src/test/scala/org/apache/spark/util/UtilsSuite.scala

Lines changed: 16 additions & 0 deletions
@@ -1168,6 +1168,22 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
       Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port")
     }
   }
+
+  object MalformedClassObject {
+    class MalformedClass
+  }
+
+  test("Safe getSimpleName") {
+    // getSimpleName on class of MalformedClass will result in error: Malformed class name
+    // Utils.getSimpleName works
+    val err = intercept[java.lang.InternalError] {
+      classOf[MalformedClassObject.MalformedClass].getSimpleName
+    }
+    assert(err.getMessage === "Malformed class name")
+
+    assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) ===
+      "UtilsSuite$MalformedClassObject$MalformedClass")
+  }
 }

 private class SimpleExtension

mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala

Lines changed: 4 additions & 1 deletion
@@ -30,6 +30,7 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.Param
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.apache.spark.util.Utils

 /**
  * A small wrapper that defines a training session for an estimator, and some methods to log
@@ -47,7 +48,9 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (

   private val id = UUID.randomUUID()
   private val prefix = {
-    val className = estimator.getClass.getSimpleName
+    // estimator.getClass.getSimpleName can cause Malformed class name error,
+    // call safer `Utils.getSimpleName` instead
+    val className = Utils.getSimpleName(estimator.getClass)
     s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala

Lines changed: 4 additions & 1 deletion
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils

 object TypedAggregateExpression {
   def apply[BUF : Encoder, OUT : Encoder](
@@ -109,7 +110,9 @@ trait TypedAggregateExpression extends AggregateFunction {
     s"$nodeName($input)"
   }

-  override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
+  // aggregator.getClass.getSimpleName can cause Malformed class name error,
+  // call safer `Utils.getSimpleName` instead
+  override def nodeName: String = Utils.getSimpleName(aggregator.getClass).stripSuffix("$");
 }

 // TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface.
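To connect this back to the user-facing path described in the PR, here is a hedged sketch (the `MyUtils`/`sumAgg` names and the exact mangled class name are illustrative, not taken from the patch): an Aggregator defined inside a nested scope gets a '$'-mangled class name, and `nodeName` above is what renders it, e.g. in the default output column name and in explain() output.

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

object MyUtils {
  // Defining the aggregator as an object inside a method yields a binary name
  // along the lines of "MyUtils$myAgg$1$" (assumed; the exact suffix depends on scalac).
  def sumAgg: Aggregator[Int, Int, Int] = {
    object myAgg extends Aggregator[Int, Int, Int] {
      def zero: Int = 0
      def reduce(b: Int, a: Int): Int = b + a
      def merge(b1: Int, b2: Int): Int = b1 + b2
      def finish(r: Int): Int = r
      def bufferEncoder: Encoder[Int] = Encoders.scalaInt
      def outputEncoder: Encoder[Int] = Encoders.scalaInt
    }
    myAgg
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
    import spark.implicits._
    // nodeName is exercised when the typed aggregate expression is stringified;
    // with this patch it no longer risks throwing for mangled class names.
    Seq(1, 2, 3).toDS().select(sumAgg.toColumn).show()
    spark.stop()
  }
}
```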

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala

Lines changed: 3 additions & 1 deletion
@@ -53,7 +53,9 @@ trait DataSourceV2StringFormat {

   private def sourceName: String = source match {
     case registered: DataSourceRegister => registered.shortName()
-    case _ => source.getClass.getSimpleName.stripSuffix("$")
+    // source.getClass.getSimpleName can cause Malformed class name error,
+    // call safer `Utils.getSimpleName` instead
+    case _ => Utils.getSimpleName(source.getClass)
   }

   def metadataString: String = {
