Commit 5acaccd

Properly call serializer's constructors.
1 parent 2a8d75a commit 5acaccd

File tree

2 files changed: +27 additions, -10 deletions

  core/src/main/scala/org/apache/spark/SparkEnv.scala
  core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala


core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 9 additions & 1 deletion
@@ -138,7 +138,15 @@ object SparkEnv extends Logging {
     // defaultClassName if the property is not set, and return it as a T
     def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
       val name = conf.get(propertyName, defaultClassName)
-      Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
+      val cls = Class.forName(name, true, classLoader)
+      // First try with the constructor that takes SparkConf. If we can't find one,
+      // use a no-arg constructor instead.
+      try {
+        cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+      } catch {
+        case _: NoSuchMethodException =>
+          cls.getConstructor().newInstance().asInstanceOf[T]
+      }
     }

     val serializer = instantiateClass[Serializer](
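
The SparkEnv change is a straightforward reflection fallback: ask for the constructor that takes a SparkConf, and only if getConstructor throws NoSuchMethodException fall back to the no-arg constructor. Below is a minimal self-contained sketch of the same pattern; Config, WithConfCtor, NoArgCtor, and InstantiateDemo are hypothetical stand-ins made up for illustration, not Spark types.

// Hypothetical stand-ins, not Spark classes.
class Config(val settings: Map[String, String])

class WithConfCtor(conf: Config) {
  override def toString = s"WithConfCtor(${conf.settings.size} settings)"
}

class NoArgCtor {
  override def toString = "NoArgCtor()"
}

object InstantiateDemo {
  // Same shape as instantiateClass above: prefer the Config-taking
  // constructor, fall back to the no-arg one if it does not exist.
  def instantiate[T](className: String, conf: Config): T = {
    val cls = Class.forName(className)
    try {
      cls.getConstructor(classOf[Config]).newInstance(conf).asInstanceOf[T]
    } catch {
      case _: NoSuchMethodException =>
        cls.getConstructor().newInstance().asInstanceOf[T]
    }
  }

  def main(args: Array[String]) {
    val conf = new Config(Map("spark.serializer.objectStreamReset" -> "10000"))
    println(instantiate[AnyRef]("WithConfCtor", conf)) // one-arg constructor wins
    println(instantiate[AnyRef]("NoArgCtor", conf))    // falls back to no-arg
  }
}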

core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala

Lines changed: 18 additions & 9 deletions
@@ -23,11 +23,10 @@ import java.nio.ByteBuffer
 import org.apache.spark.SparkConf
 import org.apache.spark.util.ByteBufferInputStream

-private[spark] class JavaSerializationStream(out: OutputStream, conf: SparkConf)
+private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
 extends SerializationStream {
-  val objOut = new ObjectOutputStream(out)
-  var counter = 0
-  val counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000)
+  private val objOut = new ObjectOutputStream(out)
+  private var counter = 0

   /**
    * Calling reset to avoid memory leak:
@@ -51,7 +50,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, conf: SparkConf)

 private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
 extends DeserializationStream {
-  val objIn = new ObjectInputStream(in) {
+  private val objIn = new ObjectInputStream(in) {
     override def resolveClass(desc: ObjectStreamClass) =
       Class.forName(desc.getName, false, loader)
   }
@@ -60,7 +59,7 @@ extends DeserializationStream {
   def close() { objIn.close() }
 }

-private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerInstance {
+private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
   def serialize[T](t: T): ByteBuffer = {
     val bos = new ByteArrayOutputStream()
     val out = serializeStream(bos)
@@ -82,7 +81,7 @@ private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerInstance {
   }

   def serializeStream(s: OutputStream): SerializationStream = {
-    new JavaSerializationStream(s, conf)
+    new JavaSerializationStream(s, counterReset)
   }

   def deserializeStream(s: InputStream): DeserializationStream = {
@@ -97,6 +96,16 @@ private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerInstance {
 /**
  * A Spark serializer that uses Java's built-in serialization.
  */
-class JavaSerializer(conf: SparkConf) extends Serializer with Serializable {
-  def newInstance(): SerializerInstance = new JavaSerializerInstance(conf)
+class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
+  private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000)
+
+  def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset)
+
+  override def writeExternal(out: ObjectOutput) {
+    out.writeInt(counterReset)
+  }
+
+  override def readExternal(in: ObjectInput) {
+    counterReset = in.readInt()
+  }
 }
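
Note what the Externalizable switch in the last hunk buys: conf is now read once to initialize counterReset rather than being kept around, and writeExternal/readExternal serialize exactly one int instead of the whole SparkConf. Below is a minimal round-trip sketch of that pattern; SettingsHolder and ExternalizableDemo are hypothetical stand-ins, not Spark classes.

import java.io._

// Hypothetical stand-in for JavaSerializer: derives one int from a config
// map and serializes only that int.
class SettingsHolder(settings: Map[String, String]) extends Externalizable {
  private var counterReset =
    settings.getOrElse("spark.serializer.objectStreamReset", "10000").toInt

  // Externalizable requires a public no-arg constructor: deserialization
  // builds the object with it, then calls readExternal to restore state.
  def this() = this(Map.empty)

  override def writeExternal(out: ObjectOutput) {
    out.writeInt(counterReset)
  }

  override def readExternal(in: ObjectInput) {
    counterReset = in.readInt()
  }

  override def toString = s"SettingsHolder(counterReset=$counterReset)"
}

object ExternalizableDemo {
  def main(args: Array[String]) {
    val buf = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(buf)
    oos.writeObject(new SettingsHolder(Map("spark.serializer.objectStreamReset" -> "250")))
    oos.close()

    val ois = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
    // Prints SettingsHolder(counterReset=250): only the int crossed the wire.
    println(ois.readObject())
  }
}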
