@@ -88,6 +88,50 @@ class KryoSerializer(conf: SparkConf)
   private val useUnsafe = conf.get(KRYO_USE_UNSAFE)
   private val usePool = conf.get(KRYO_USE_POOL)

+  // classForName() is expensive in case the class is not found, so we filter the list of
+  // SQL / ML / MLlib classes once and then re-use that filtered list in newInstance() calls.
+  private lazy val loadableClasses: Seq[Class[_]] = {
[Inline review thread on the loadableClasses definition; a sketch of the suggested refactor appears after this hunk.]

JoshRosen (Contributor) commented on Jun 19, 2019:

Could this be moved into a private[serializer] field in an object KryoSerializer companion? Now that I look at this again, I'm worried that it'll be serialized as part of KryoSerializer itself, since I think the serializer itself is serialized as part of ShuffleDependency. I don't think that's a huge deal, but we could probably shave off some additional work with that extra step.

Contributor:

Here's a commit with that change: JoshRosen@c8680f9

Contributor:

Actually, I just pushed directly to your branch using GitHub's new "allow edits from maintainers" feature. Hope you don't mind 😄!

Author (Member):

Sure. Feel free to push it.

+    Seq(
+      "org.apache.spark.sql.catalyst.expressions.UnsafeRow",
+      "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData",
+      "org.apache.spark.sql.catalyst.expressions.UnsafeMapData",
+
+      "org.apache.spark.ml.attribute.Attribute",
+      "org.apache.spark.ml.attribute.AttributeGroup",
+      "org.apache.spark.ml.attribute.BinaryAttribute",
+      "org.apache.spark.ml.attribute.NominalAttribute",
+      "org.apache.spark.ml.attribute.NumericAttribute",
+
+      "org.apache.spark.ml.feature.Instance",
+      "org.apache.spark.ml.feature.LabeledPoint",
+      "org.apache.spark.ml.feature.OffsetInstance",
+      "org.apache.spark.ml.linalg.DenseMatrix",
+      "org.apache.spark.ml.linalg.DenseVector",
+      "org.apache.spark.ml.linalg.Matrix",
+      "org.apache.spark.ml.linalg.SparseMatrix",
+      "org.apache.spark.ml.linalg.SparseVector",
+      "org.apache.spark.ml.linalg.Vector",
+      "org.apache.spark.ml.stat.distribution.MultivariateGaussian",
+      "org.apache.spark.ml.tree.impl.TreePoint",
+      "org.apache.spark.mllib.clustering.VectorWithNorm",
+      "org.apache.spark.mllib.linalg.DenseMatrix",
+      "org.apache.spark.mllib.linalg.DenseVector",
+      "org.apache.spark.mllib.linalg.Matrix",
+      "org.apache.spark.mllib.linalg.SparseMatrix",
+      "org.apache.spark.mllib.linalg.SparseVector",
+      "org.apache.spark.mllib.linalg.Vector",
+      "org.apache.spark.mllib.regression.LabeledPoint",
+      "org.apache.spark.mllib.stat.distribution.MultivariateGaussian"
+    ).flatMap { name =>
+      try {
+        Some[Class[_]](Utils.classForName(name))
+      } catch {
+        case NonFatal(_) => None // do nothing
+        case _: NoClassDefFoundError if Utils.isTesting => None // See SPARK-23422.
+      }
+    }
+  }
+
   def newKryoOutput(): KryoOutput =
     if (useUnsafe) {
       new KryoUnsafeOutput(bufferSize, math.max(bufferSize, maxBufferSize))
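The review thread above suggests hoisting the filtered list into the KryoSerializer companion object so it is not captured when the serializer instance itself is serialized (e.g. as part of a ShuffleDependency). The linked commit (JoshRosen@c8680f9) is not shown here, so the following is only a hedged sketch of what that refactor could look like; the field name, visibility, and the trimmed class list are assumptions, not the actual change.

```scala
package org.apache.spark.serializer

import scala.util.control.NonFatal

// Sketch only (not the contents of the linked commit): keep the pre-filtered class list
// in the companion object so it is resolved once per JVM and is not serialized along
// with individual KryoSerializer instances.
object KryoSerializer {

  // Abbreviated candidate list; the real list is the full SQL / ML / MLlib set in the diff above.
  private val candidateClassNames = Seq(
    "org.apache.spark.sql.catalyst.expressions.UnsafeRow",
    "org.apache.spark.ml.linalg.DenseVector",
    "org.apache.spark.mllib.regression.LabeledPoint")

  // Resolved lazily, exactly once, and shared by every serializer instance.
  private[serializer] lazy val loadableClasses: Seq[Class[_]] = candidateClassNames.flatMap { name =>
    try {
      // The real code goes through Utils.classForName; plain Class.forName keeps the sketch self-contained.
      Some[Class[_]](Class.forName(name))
    } catch {
      case NonFatal(_) => None // class is not on the classpath; skip it
    }
  }
}
```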
@@ -212,40 +256,8 @@

     // We can't load those class directly in order to avoid unnecessary jar dependencies.
     // We load them safely, ignore it if the class not found.
-    Seq(
-      "org.apache.spark.sql.catalyst.expressions.UnsafeRow",
-      "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData",
-      "org.apache.spark.sql.catalyst.expressions.UnsafeMapData",
-
-      "org.apache.spark.ml.attribute.Attribute",
-      "org.apache.spark.ml.attribute.AttributeGroup",
-      "org.apache.spark.ml.attribute.BinaryAttribute",
-      "org.apache.spark.ml.attribute.NominalAttribute",
-      "org.apache.spark.ml.attribute.NumericAttribute",
-
-      "org.apache.spark.ml.feature.Instance",
-      "org.apache.spark.ml.feature.LabeledPoint",
-      "org.apache.spark.ml.feature.OffsetInstance",
-      "org.apache.spark.ml.linalg.DenseMatrix",
-      "org.apache.spark.ml.linalg.DenseVector",
-      "org.apache.spark.ml.linalg.Matrix",
-      "org.apache.spark.ml.linalg.SparseMatrix",
-      "org.apache.spark.ml.linalg.SparseVector",
-      "org.apache.spark.ml.linalg.Vector",
-      "org.apache.spark.ml.stat.distribution.MultivariateGaussian",
-      "org.apache.spark.ml.tree.impl.TreePoint",
-      "org.apache.spark.mllib.clustering.VectorWithNorm",
-      "org.apache.spark.mllib.linalg.DenseMatrix",
-      "org.apache.spark.mllib.linalg.DenseVector",
-      "org.apache.spark.mllib.linalg.Matrix",
-      "org.apache.spark.mllib.linalg.SparseMatrix",
-      "org.apache.spark.mllib.linalg.SparseVector",
-      "org.apache.spark.mllib.linalg.Vector",
-      "org.apache.spark.mllib.regression.LabeledPoint",
-      "org.apache.spark.mllib.stat.distribution.MultivariateGaussian"
-    ).foreach { name =>
+    loadableClasses.foreach { clazz =>
       try {
-        val clazz = Utils.classForName(name)
         kryo.register(clazz)
       } catch {
         case NonFatal(_) => // do nothing
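To make the registration hunk concrete outside of Spark: below is a minimal, standalone sketch of the same load-once / register-per-instance pattern against the plain Kryo API. The object name and candidate class names are placeholders, not Spark code, and the Kryo dependency (com.esotericsoftware:kryo) is assumed to be on the classpath.

```scala
import scala.util.control.NonFatal

import com.esotericsoftware.kryo.Kryo

// Standalone sketch of the pattern in the diff: resolve optional classes by name once,
// then register only the loadable ones with each freshly created Kryo instance.
object OptionalKryoRegistration {

  // Placeholder candidates; in Spark this is the SQL / ML / MLlib list shown above.
  private val candidateClassNames = Seq(
    "java.util.ArrayList",              // present on every JVM
    "org.example.optional.FancyVector") // deliberately missing; will be filtered out

  // Filtered once: class lookups by name are expensive when the class is absent.
  private lazy val loadableClasses: Seq[Class[_]] = candidateClassNames.flatMap { name =>
    try Some[Class[_]](Class.forName(name))
    catch { case NonFatal(_) => None }
  }

  // Mirrors the new foreach in the diff: no class lookup here, just best-effort registration.
  def registerOptional(kryo: Kryo): Unit = loadableClasses.foreach { clazz =>
    try kryo.register(clazz)
    catch { case NonFatal(_) => () } // ignore optional registration failures
  }

  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    registerOptional(kryo)
    println(s"Registered ${loadableClasses.size} of ${candidateClassNames.size} candidate classes")
  }
}
```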