@@ -61,8 +61,7 @@ private[python] object Converter extends Logging {
* Other objects are passed through without conversion.
*/
private[python] class WritableToJavaConverter(
conf: Broadcast[SerializableWritable[Configuration]],
batchSize: Int) extends Converter[Any, Any] {
conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] {

/**
* Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
@@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter(
map.put(convertWritable(k), convertWritable(v))
}
map
case w: Writable =>
if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w
case w: Writable => WritableUtils.clone(w, conf.value.value)
case other => other
}
}
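For context, a hedged sketch (not part of this patch) of why the converter can simply clone every Writable: Hadoop record readers reuse a single Writable instance across records, so any consumer that buffers converted records, as the batched Python serializers downstream do, must copy each Writable before the reader overwrites it. With batching now decided later (and adaptively) in SerDeUtil, the converter no longer receives a batch size, so it clones unconditionally.

// Illustrative only: shows the reuse hazard WritableUtils.clone guards against.
// Assumes hadoop-common on the classpath; `reused` stands in for a RecordReader's
// reused key/value object.
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{Text, WritableUtils}

val conf = new Configuration()
val reused = new Text("record-1")
val buffered = reused                           // buffering the bare reference is unsafe
val cloned = WritableUtils.clone(reused, conf)  // independent copy, safe to buffer

reused.set("record-2")                          // the reader moves on to the next record
// buffered.toString is now "record-2" (silently overwritten); cloned.toString is still "record-1"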
110 changes: 5 additions & 105 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -22,12 +22,10 @@ import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials

import com.google.common.base.Charsets.UTF_8
import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
@@ -442,7 +440,7 @@ private[spark] object PythonRDD extends Logging {
val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration()))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
new WritableToJavaConverter(confBroadcasted, batchSize))
new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}

@@ -468,7 +466,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
new WritableToJavaConverter(confBroadcasted, batchSize))
new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}

@@ -494,7 +492,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
new WritableToJavaConverter(confBroadcasted, batchSize))
new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}

@@ -537,7 +535,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
new WritableToJavaConverter(confBroadcasted, batchSize))
new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}

@@ -563,7 +561,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
new WritableToJavaConverter(confBroadcasted, batchSize))
new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}

@@ -746,104 +744,6 @@ private[spark] object PythonRDD extends Logging {
converted.saveAsHadoopDataset(new JobConf(conf))
}
}


/**
* Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
*/
@deprecated("PySpark does not use it anymore", "1.1")
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
SerDeUtil.initialize()
iter.flatMap { row =>
unpickle.loads(row) match {
// in case of objects are pickled in batch mode
case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
// not in batch mode
case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
}
}
}
}

/**
* Convert an RDD of serialized Python tuple to Array (no recursive conversions).
* It is only used by pyspark.sql.
*/
def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {

def toArray(obj: Any): Array[_] = {
obj match {
case objs: JArrayList[_] =>
objs.toArray
case obj if obj.getClass.isArray =>
obj.asInstanceOf[Array[_]].toArray
}
}

pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].map(toArray)
} else {
Seq(toArray(obj))
}
}
}.toJavaRDD()
}

private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
private val pickle = new Pickler()
private var batch = 1
private val buffer = new mutable.ArrayBuffer[Any]

override def hasNext(): Boolean = iter.hasNext

override def next(): Array[Byte] = {
while (iter.hasNext && buffer.length < batch) {
buffer += iter.next()
}
val bytes = pickle.dumps(buffer.toArray)
val size = bytes.length
// let 1M < size < 10M
if (size < 1024 * 1024) {
batch *= 2
} else if (size > 1024 * 1024 * 10 && batch > 1) {
batch /= 2
}
buffer.clear()
bytes
}
}

/**
* Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
*/
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
}

/**
* Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
*/
def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
pyRDD.rdd.mapPartitions { iter =>
SerDeUtil.initialize()
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].asScala
} else {
Seq(obj)
}
}
}.toJavaRDD()
}
}

private
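The helpers removed above are not simply dropped: except for the deprecated pythonToJavaMap, they move to SerDeUtil in the next file (pythonToJavaArray is superseded there by a toJavaArray that works on already-unpickled objects). A hypothetical call-site update inside Spark, assuming a JavaRDD[Array[Byte]] of pickled rows named pickledRDD:

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil

// Before this patch: PythonRDD.pythonToJava(pickledRDD, batched = true)
val objects: JavaRDD[Any] = SerDeUtil.pythonToJava(pickledRDD, batched = true)
// And the reverse direction, previously PythonRDD.javaToPython:
val repickled: JavaRDD[Array[Byte]] = SerDeUtil.javaToPython(objects)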
121 changes: 90 additions & 31 deletions core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -18,8 +18,13 @@
package org.apache.spark.api.python

import java.nio.ByteOrder
import java.util.{ArrayList => JArrayList}

import org.apache.spark.api.java.JavaRDD

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Failure
import scala.util.Try

@@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging {
}
initialize()


/**
* Convert an RDD of Java objects to Array (no recursive conversions).
* It is only used by pyspark.sql.
*/
def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = {
jrdd.rdd.map {
case objs: JArrayList[_] =>
objs.toArray
case obj if obj.getClass.isArray =>
obj.asInstanceOf[Array[_]].toArray
}.toJavaRDD()
}
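
An illustrative note on toJavaArray (a sketch, not from the patch): rows unpickled from Python arrive either as java.util.ArrayList or as JVM arrays, and both are normalized to Array[_] for pyspark.sql.

import java.util.{Arrays, ArrayList => JArrayList}

// ArrayList input (typical unpickler output) is flattened to an Object array...
val fromList: Array[AnyRef] = new JArrayList[AnyRef](Arrays.asList("a", "b")).toArray
// ...while an existing array is copied via toArray.
val fromArray: Array[_] = Array[Any]("a", 1).toArray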

/**
* Choose batch size based on size of objects
*/
private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
private val pickle = new Pickler()
private var batch = 1
private val buffer = new mutable.ArrayBuffer[Any]

override def hasNext: Boolean = iter.hasNext

override def next(): Array[Byte] = {
while (iter.hasNext && buffer.length < batch) {
buffer += iter.next()
}
val bytes = pickle.dumps(buffer.toArray)
val size = bytes.length
// let 1M < size < 10M
if (size < 1024 * 1024) {
batch *= 2
} else if (size > 1024 * 1024 * 10 && batch > 1) {
batch /= 2
}
buffer.clear()
bytes
}
}
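
An illustrative restatement (not part of the patch) of the sizing rule above: the batch doubles while a pickled payload stays under 1 MiB and is halved when it exceeds 10 MiB, so payload sizes settle between the two bounds.

// Hypothetical helper mirroring the thresholds in AutoBatchedPickler.next() above.
def nextBatchSize(currentBatch: Int, lastPayloadBytes: Int): Int =
  if (lastPayloadBytes < 1024 * 1024) currentBatch * 2                                // too small: double
  else if (lastPayloadBytes > 10 * 1024 * 1024 && currentBatch > 1) currentBatch / 2  // too large: halve
  else currentBatch                                                                   // in range: keep

// e.g. with ~100 KB records batch sizes grow 1 -> 2 -> 4 -> 8 -> 16, then hold once payloads pass 1 MiB.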

/**
* Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
*/
private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
}

/**
* Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
*/
def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
pyRDD.rdd.mapPartitions { iter =>
initialize()
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].asScala
} else {
Seq(obj)
}
}
}.toJavaRDD()
}

private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
val pickle = new Pickler
val kt = Try {
@@ -128,54 +200,41 @@ private[spark] object SerDeUtil extends Logging {
*/
def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = {
val (keyFailed, valueFailed) = checkPickle(rdd.first())

rdd.mapPartitions { iter =>
val pickle = new Pickler
val cleaned = iter.map { case (k, v) =>
val key = if (keyFailed) k.toString else k
val value = if (valueFailed) v.toString else v
Array[Any](key, value)
}
if (batchSize > 1) {
cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
if (batchSize == 0) {
new AutoBatchedPickler(cleaned)
} else {
cleaned.map(pickle.dumps(_))
val pickle = new Pickler
cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
}
}
}
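
An illustrative usage of the new contract (not from the patch), assuming an RDD[(Any, Any)] named pairs: batchSize == 0 now selects the adaptive AutoBatchedPickler, while any positive value keeps fixed-size batches of that many records per pickled payload.

val autoBatched  = SerDeUtil.pairRDDToPython(pairs, batchSize = 0)    // adaptive batching
val fixedBatches = SerDeUtil.pairRDDToPython(pairs, batchSize = 100)  // 100 records per pickle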

/**
* Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)].
*/
def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = {
def isPair(obj: Any): Boolean = {
Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) &&
Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
obj.asInstanceOf[Array[_]].length == 2
}
pyRDD.mapPartitions { iter =>
initialize()
val unpickle = new Unpickler
val unpickled =
if (batchSerialized) {
iter.flatMap { batch =>
unpickle.loads(batch) match {
case objs: java.util.List[_] => collectionAsScalaIterable(objs)
case other => throw new SparkException(
s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD")
}
}
} else {
iter.map(unpickle.loads(_))
}
unpickled.map {
case obj if isPair(obj) =>
// we only accept (K, V)
val arr = obj.asInstanceOf[Array[_]]
(arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
case other => throw new SparkException(
s"RDD element of type ${other.getClass.getName} cannot be used")
}

val rdd = pythonToJava(pyRDD, batched).rdd
rdd.first match {
case obj if isPair(obj) =>
// we only accept (K, V)
case other => throw new SparkException(
s"RDD element of type ${other.getClass.getName} cannot be used")
}
rdd.map { obj =>
val arr = obj.asInstanceOf[Array[_]]
(arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
}
}

}
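
A hedged sketch of the pair check applied above (the helper name is illustrative): only a non-primitive array of exactly two elements is accepted as a (K, V) record, and the rewritten method now verifies this once, on the first element, instead of on every record.

// Mirrors the isPair predicate above; looksLikePair is a hypothetical standalone copy.
def looksLikePair(obj: Any): Boolean =
  Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
    obj.asInstanceOf[Array[_]].length == 2

looksLikePair(Array[Any]("key", 42))  // true:  Object array of length 2
looksLikePair(Array[Any](1, 2, 3))    // false: wrong arity
looksLikePair(Array(1.0, 2.0))        // false: Array[Double] has a primitive component type
looksLikePair("not an array")         // false: getComponentType is null, short-circuits safely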

@@ -176,11 +176,11 @@ object WriteInputFormatTestDataGenerator {

// Create test data for arbitrary custom writable TestWritable
val testClass = Seq(
("1", TestWritable("test1", 123, 54.0)),
("2", TestWritable("test2", 456, 8762.3)),
("1", TestWritable("test3", 123, 423.1)),
("3", TestWritable("test56", 456, 423.5)),
("2", TestWritable("test2", 123, 5435.2))
("1", TestWritable("test1", 1, 1.0)),
("2", TestWritable("test2", 2, 2.3)),
("3", TestWritable("test3", 3, 3.1)),
("5", TestWritable("test56", 5, 5.5)),
("4", TestWritable("test4", 4, 4.2))
)
val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) }
rdd.saveAsNewAPIHadoopFile(classPath,
@@ -736,7 +736,7 @@ private[spark] object SerDe extends Serializable {
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter =>
initialize() // let it called in executor
new PythonRDD.AutoBatchedPickler(iter)
new SerDeUtil.AutoBatchedPickler(iter)
}
}
