100 changes: 0 additions & 100 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -22,12 +22,10 @@ import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials

import com.google.common.base.Charsets.UTF_8
import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
@@ -746,104 +744,6 @@ private[spark] object PythonRDD extends Logging {
converted.saveAsHadoopDataset(new JobConf(conf))
}
}


/**
* Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
*/
@deprecated("PySpark does not use it anymore", "1.1")
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
SerDeUtil.initialize()
iter.flatMap { row =>
unpickle.loads(row) match {
// in case of objects are pickled in batch mode
case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
// not in batch mode
case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
}
}
}
}

/**
* Convert an RDD of serialized Python tuple to Array (no recursive conversions).
* It is only used by pyspark.sql.
*/
def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {

def toArray(obj: Any): Array[_] = {
obj match {
case objs: JArrayList[_] =>
objs.toArray
case obj if obj.getClass.isArray =>
obj.asInstanceOf[Array[_]].toArray
}
}

pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].map(toArray)
} else {
Seq(toArray(obj))
}
}
}.toJavaRDD()
}

private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
private val pickle = new Pickler()
private var batch = 1
private val buffer = new mutable.ArrayBuffer[Any]

override def hasNext(): Boolean = iter.hasNext

override def next(): Array[Byte] = {
while (iter.hasNext && buffer.length < batch) {
buffer += iter.next()
}
val bytes = pickle.dumps(buffer.toArray)
val size = bytes.length
// let 1M < size < 10M
if (size < 1024 * 1024) {
batch *= 2
} else if (size > 1024 * 1024 * 10 && batch > 1) {
batch /= 2
}
buffer.clear()
bytes
}
}

/**
* Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
*/
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
}

/**
* Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
*/
def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
pyRDD.rdd.mapPartitions { iter =>
SerDeUtil.initialize()
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].asScala
} else {
Seq(obj)
}
}
}.toJavaRDD()
}
}

private
121 changes: 90 additions & 31 deletions core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -18,8 +18,13 @@
package org.apache.spark.api.python

import java.nio.ByteOrder
import java.util.{ArrayList => JArrayList}

import org.apache.spark.api.java.JavaRDD

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Failure
import scala.util.Try

@@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging {
}
initialize()


/**
* Convert an RDD of Java objects to Array (no recursive conversions).
* It is only used by pyspark.sql.
*/
def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = {
jrdd.rdd.map {
case objs: JArrayList[_] =>
objs.toArray
case obj if obj.getClass.isArray =>
obj.asInstanceOf[Array[_]].toArray
}.toJavaRDD()
}

/**
* Choose batch size based on size of objects
*/
private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
private val pickle = new Pickler()
private var batch = 1
private val buffer = new mutable.ArrayBuffer[Any]

override def hasNext: Boolean = iter.hasNext

override def next(): Array[Byte] = {
while (iter.hasNext && buffer.length < batch) {
buffer += iter.next()
}
val bytes = pickle.dumps(buffer.toArray)
val size = bytes.length
// let 1M < size < 10M
if (size < 1024 * 1024) {
batch *= 2
} else if (size > 1024 * 1024 * 10 && batch > 1) {
batch /= 2
}
buffer.clear()
bytes
}
}

/**
* Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
*/
private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
}

/**
* Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
*/
def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
pyRDD.rdd.mapPartitions { iter =>
initialize()
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].asScala
} else {
Seq(obj)
}
}
}.toJavaRDD()
}

private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
val pickle = new Pickler
val kt = Try {
@@ -128,54 +200,41 @@
*/
def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = {
val (keyFailed, valueFailed) = checkPickle(rdd.first())

rdd.mapPartitions { iter =>
val pickle = new Pickler
val cleaned = iter.map { case (k, v) =>
val key = if (keyFailed) k.toString else k
val value = if (valueFailed) v.toString else v
Array[Any](key, value)
}
if (batchSize > 1) {
cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
if (batchSize == 0) {
new AutoBatchedPickler(cleaned)
} else {
cleaned.map(pickle.dumps(_))
val pickle = new Pickler
cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
}
}
}

/**
* Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)].
*/
def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = {
def isPair(obj: Any): Boolean = {
Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) &&
Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
obj.asInstanceOf[Array[_]].length == 2
}
pyRDD.mapPartitions { iter =>
initialize()
val unpickle = new Unpickler
val unpickled =
if (batchSerialized) {
iter.flatMap { batch =>
unpickle.loads(batch) match {
case objs: java.util.List[_] => collectionAsScalaIterable(objs)
case other => throw new SparkException(
s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD")
}
}
} else {
iter.map(unpickle.loads(_))
}
unpickled.map {
case obj if isPair(obj) =>
// we only accept (K, V)
val arr = obj.asInstanceOf[Array[_]]
(arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
case other => throw new SparkException(
s"RDD element of type ${other.getClass.getName} cannot be used")
}

val rdd = pythonToJava(pyRDD, batched).rdd
rdd.first match {
case obj if isPair(obj) =>
// we only accept (K, V)
case other => throw new SparkException(
s"RDD element of type ${other.getClass.getName} cannot be used")
}
rdd.map { obj =>
val arr = obj.asInstanceOf[Array[_]]
(arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
}
}

}
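
AutoBatchedPickler above adapts its batch size to the data: it doubles the batch while a pickled chunk stays under 1 MB and halves it once a chunk exceeds 10 MB, so each serialized chunk handed to the Python worker lands in that range (pairRDDToPython takes the same path when batchSize is 0). Below is a minimal, self-contained Python sketch of that heuristic using the standard pickle module; auto_batched_dumps is a hypothetical name introduced only for illustration and is not part of PySpark.

import pickle

def auto_batched_dumps(records):
    # Sketch of the SerDeUtil.AutoBatchedPickler heuristic: pickle the input
    # in batches, doubling the batch size while a chunk is under 1 MB and
    # halving it (never below 1) once a chunk exceeds 10 MB.
    it = iter(records)
    batch = 1
    buffer = []
    while True:
        try:
            while len(buffer) < batch:
                buffer.append(next(it))
        except StopIteration:
            if not buffer:
                return
        chunk = pickle.dumps(buffer)
        size = len(chunk)
        # keep 1 MB < size < 10 MB, as in the Scala code above
        if size < 1024 * 1024:
            batch *= 2
        elif size > 1024 * 1024 * 10 and batch > 1:
            batch //= 2
        buffer = []
        yield chunk

# Example: many small records come out as a handful of progressively larger chunks.
chunks = list(auto_batched_dumps("row-%d" % i for i in range(200000)))
print([len(c) for c in chunks])

Batching this way keeps per-record pickling overhead low for tiny objects without producing chunks so large that they stall the pipe to the worker.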

@@ -736,7 +736,7 @@ private[spark] object SerDe extends Serializable {
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter =>
initialize() // let it be called in the executor
new PythonRDD.AutoBatchedPickler(iter)
new SerDeUtil.AutoBatchedPickler(iter)
}
}

16 changes: 5 additions & 11 deletions python/pyspark/context.py
@@ -115,9 +115,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
self._unbatched_serializer = serializer
if batchSize == 1:
self.serializer = self._unbatched_serializer
elif batchSize == 0:
if batchSize == 0:
self.serializer = AutoBatchedSerializer(self._unbatched_serializer)
else:
self.serializer = BatchedSerializer(self._unbatched_serializer,
@@ -305,12 +303,8 @@ def parallelize(self, c, numSlices=None):
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
batchSize = min(len(c) // numSlices, self._batchSize)
if batchSize > 1:
serializer = BatchedSerializer(self._unbatched_serializer,
batchSize)
else:
serializer = self._unbatched_serializer
batchSize = max(1, min(len(c) // numSlices, self._batchSize))
serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
serializer.dump_stream(c, tempFile)
tempFile.close()
readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
@@ -431,7 +425,7 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None,
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
ser = BatchedSerializer(PickleSerializer())
jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass,
keyConverter, valueConverter, minSplits, batchSize)
return RDD(jrdd, self, ser)
@@ -836,7 +830,7 @@ def _test():
import doctest
import tempfile
globs = globals().copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
globs['sc'] = SparkContext('local[4]', 'PythonTest')
globs['tempdir'] = tempfile.mkdtemp()
atexit.register(lambda: shutil.rmtree(globs['tempdir']))
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
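
With the change above, a batchSize of 0 selects AutoBatchedSerializer and any other value a BatchedSerializer (with -1 meaning an unlimited batch), and parallelize() now always batches with a size of at least 1, capped by the configured batch size and by how evenly the data spreads over the slices. A small sketch of that computation follows; choose_batch_size and batched are hypothetical helper names used only for illustration, not PySpark APIs.

def choose_batch_size(num_records, num_slices, configured_batch_size):
    # Mirror of the updated parallelize() logic: spread small inputs evenly
    # across slices, never exceed the configured batch size, never drop below 1.
    return max(1, min(num_records // num_slices, configured_batch_size))

def batched(records, batch_size):
    # Group records into fixed-size batches, the way BatchedSerializer frames
    # them before pickling each group.
    records = list(records)
    for start in range(0, len(records), batch_size):
        yield records[start:start + batch_size]

# 10 records over 4 slices with a configured batch size of 1024 -> batches of 2.
records = list(range(10))
size = choose_batch_size(len(records), 4, 1024)
print(size)                              # 2
print(list(batched(records, size)))      # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]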
2 changes: 1 addition & 1 deletion python/pyspark/mllib/common.py
@@ -96,7 +96,7 @@ def _java2py(sc, r):

if clsName == 'JavaRDD':
jrdd = sc._jvm.SerDe.javaToPython(r)
return RDD(jrdd, sc, AutoBatchedSerializer(PickleSerializer()))
return RDD(jrdd, sc)

elif isinstance(r, (JavaArray, JavaList)) or clsName in _picklable_classes:
r = sc._jvm.SerDe.dumps(r)
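
_java2py now wraps the JavaRDD returned by SerDe.javaToPython without naming a serializer explicitly; the bytes it receives are the auto-batched pickles produced by AutoBatchedPickler, so each element is one pickled list of records. Below is a sketch of decoding such chunks back into individual records, the counterpart of the auto_batched_dumps sketch earlier; auto_batched_loads is again a hypothetical name, not a PySpark API.

import pickle

def auto_batched_loads(chunks):
    # Each chunk is one pickled batch (a list); unpickle it and flatten,
    # the same way pythonToJava above flattens its JArrayList batches.
    for chunk in chunks:
        for record in pickle.loads(chunk):
            yield record

# Round trip with the earlier sketch:
# list(auto_batched_loads(auto_batched_dumps(range(1000)))) == list(range(1000))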
2 changes: 1 addition & 1 deletion python/pyspark/mllib/recommendation.py
@@ -117,7 +117,7 @@ def _test():
import doctest
import pyspark.mllib.recommendation
globs = pyspark.mllib.recommendation.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
globs['sc'] = SparkContext('local[4]', 'PythonTest')
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
1 change: 0 additions & 1 deletion python/pyspark/mllib/regression.py
@@ -122,7 +122,6 @@ class LinearRegressionModel(LinearRegressionModelBase):

# train_func should take two parameters, namely data and initial_weights, and
# return the result of a call to the appropriate JVM stub.
# _regression_train_wrapper is responsible for setup and error checking.
def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
Contributor: It looks like this was an unintentional change?

Contributor: Do you mind fixing this when you resolve the merge conflict?
initial_weights = initial_weights or [0.0] * len(data.first().features)
weights, intercept = train_func(_to_java_object_rdd(data, cache=True),