From 610749505b298587bdd70e1c96202fbf8485e2fe Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 7 Feb 2016 20:50:24 +0800 Subject: [PATCH 01/17] python dataset --- .../apache/spark/api/python/PythonRDD.scala | 8 ++ python/pyspark/sql/dataframe.py | 69 ++++++++++++++++ python/pyspark/sql/tests.py | 11 +++ .../sql/catalyst/plans/logical/object.scala | 19 ++++- .../org/apache/spark/sql/DataFrame.scala | 9 ++- .../spark/sql/execution/SparkStrategies.scala | 2 + .../apache/spark/sql/execution/objects.scala | 79 ++++++++++++++++++- 7 files changed, 193 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f12e2dfafa19..d779ce76d434 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -71,6 +71,14 @@ private[spark] class PythonRDD( } } +private[spark] case class PythonFunction( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVer: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]]) /** * A helper class to run Python UDFs in Spark. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 83b034fe7743..d979281732a3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -278,6 +278,18 @@ def map(self, f): """ return self.rdd.map(f) + @ignore_unicode_prefix + @since(2.0) + def mapPartitions2(self, func): + """ TODO """ + return PipelinedDataFrame(self, func) + + @ignore_unicode_prefix + @since(2.0) + def applySchema(self, schema): + """ TODO """ + return PipelinedDataFrame(self, lambda iterator: map(schema.toInternal, iterator), schema) + @ignore_unicode_prefix @since(1.3) def flatMap(self, f): @@ -1354,6 +1366,63 @@ def toPandas(self): drop_duplicates = dropDuplicates +class PipelinedDataFrame(DataFrame): + + """ TODO """ + + schemaOfPickled = StructType().add("binary", BinaryType(), False, {"pickled": True}) + + def __init__(self, prev, func, schema=None): + if schema is None: + self._schema = self.schemaOfPickled + else: + self._schema = schema + + self.func = func + self._prev_jdf = prev._jdf # maintain the pipeline + self.is_cached = False + self.sql_ctx = prev.sql_ctx + self._sc = self.sql_ctx and self.sql_ctx._sc + self._jdf_val = None + self._lazy_rdd = None + + if not isinstance(prev, PipelinedDataFrame) or not prev.is_cached: + # This transformation is the first in its stage: + self.prev_func = None + elif prev._schema is not self.schemaOfPickled and self._schema is not self.schemaOfPickled: + # The previous operation is also adding schema, override it. 
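# A minimal standalone sketch of the pipelining idea above, assuming plain
# Python iterables in place of real partitions; it reuses the _pipeline_func
# helper defined later in this file, so a chain of per-partition functions
# collapses into a single Python stage.

def _pipeline_func(prev_func, next_func):
    # Compose two iterator -> iterator functions; None means "no previous step".
    if prev_func is None:
        return next_func
    return lambda iterator: next_func(prev_func(iterator))

add_one = lambda iterator: (x + 1 for x in iterator)
double = lambda iterator: (x * 2 for x in iterator)

pipelined = _pipeline_func(_pipeline_func(None, add_one), double)
assert list(pipelined(iter([1, 2, 3]))) == [4, 6, 8]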
+ self.prev_func = prev.prev_func + else: + self.prev_func = _pipeline_func(prev.prev_func, prev.func) + + @property + def _jdf(self): + if self._jdf_val is None: + final_func = _pipeline_func(self.prev_func, self.func) + self._jdf_val = self._prev_jdf.pythonMapPartitions( + _wrap_function(final_func, self._sc), self._schema.json()) + + return self._jdf_val + + +def _wrap_function(f, sc): + from pyspark.rdd import _prepare_for_python_RDD + from pyspark.serializers import AutoBatchedSerializer + + ser = AutoBatchedSerializer(PickleSerializer()) + command = (lambda _, iterator: f(iterator), None, ser, ser) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + +def _pipeline_func(prev_func, next_func): + if prev_func is None: + return next_func + else: + return lambda iterator: next_func(prev_func(iterator)) + + def _to_scala_map(sc, jm): """ Convert a dict into a JVM Map. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e30aa0a79692..518e24b9d3f0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1153,6 +1153,17 @@ def test_functions_broadcast(self): # planner should not crash without a join broadcast(df1)._jdf.queryExecution().executedPlan() + def test_dataset(self): + ds = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + process = lambda row: {"key": 33, "value": "abc"} + ds2 = ds.mapPartitions2(lambda iterator: map(process, iterator)) + + schema = StructType().add("key", IntegerType()).add("value", StringType()) + ds3 = ds2.applySchema(schema) + result = ds3.select("key").collect() + self.assertEqual(result[0][0], 33) + self.assertEqual(result[1][0], 33) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 3f97662957b8..f2c4033988af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ @@ -91,6 +92,22 @@ case class MapPartitions( override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) } +object PythonMapPartitions { + def apply(func: PythonFunction, schema: StructType, child: LogicalPlan): PythonMapPartitions = { + PythonMapPartitions(func, schema, schema.toAttributes, child) + } +} + +case class PythonMapPartitions( + func: PythonFunction, + outputSchema: StructType, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override lazy val schema: StructType = outputSchema + + override def expressions: Seq[Expression] = Nil +} + /** Factory for constructing new `AppendColumn` nodes. 
*/ object AppendColumns { def apply[T : Encoder, U : Encoder]( @@ -208,8 +225,6 @@ case class CoGroup( left: LogicalPlan, right: LogicalPlan) extends BinaryNode with ObjectOperator { - override def producedAttributes: AttributeSet = outputSet - override def deserializers: Seq[(Expression, Seq[Attribute])] = // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to resolve // the `keyDeserializer` based on either of them, here we pick the left one. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 76c09a285dc4..225c8cb47d37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -27,7 +27,7 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.PythonRDD +import org.apache.spark.api.python.{PythonFunction, PythonRDD} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ @@ -1761,6 +1761,13 @@ class DataFrame private[sql]( } } + protected[sql] def pythonMapPartitions( + func: PythonFunction, + schemaJson: String): DataFrame = withPlan { + val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType] + PythonMapPartitions(func, schema, logicalPlan) + } + /** * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with * an execution. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 73fd22b38e1d..01ca1b520f98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -371,6 +371,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil case e @ python.EvaluatePython(udf, child, _) => python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil + case logical.PythonMapPartitions(func, schema, output, child) => + execution.PythonMapPartitions(func, schema, output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 582dda8603f4..8495e3f7aac4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -17,12 +17,19 @@ package org.apache.spark.sql.execution +import scala.collection.JavaConverters._ + +import net.razorvine.pickle.{Pickler, Unpickler} + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{PythonFunction, PythonRunner} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.ObjectType +import 
org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.types.{BinaryType, ObjectType, StructType} /** * Helper functions for physical operators that work with user defined objects. @@ -67,6 +74,76 @@ case class MapPartitions( } } +case class PythonMapPartitions( + func: PythonFunction, + outputSchema: StructType, + output: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override def producedAttributes: AttributeSet = outputSet + + override lazy val schema: StructType = outputSchema + + private def isPickled(schema: StructType): Boolean = { + schema.length == 1 && schema.head.dataType == BinaryType && + schema.head.metadata.contains("pickled") + } + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val childSchema = child.schema + val childIsPickled = isPickled(childSchema) + val outputIsPickled = isPickled(outputSchema) + + inputRDD.mapPartitions { iter => + val inputIterator = if (childIsPickled) { + iter.map(_.getBinary(0)) + } else { + EvaluatePython.registerPicklers() // register pickler for Row + + val pickle = new Pickler + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + iter.grouped(100).map { inputRows => + val toBePickled = inputRows.map { row => + EvaluatePython.toJava(row, childSchema) + }.toArray + pickle.dumps(toBePickled) + } + } + + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = + new PythonRunner( + func.command, + func.envVars, + func.pythonIncludes, + func.pythonExec, + func.pythonVer, + func.broadcastVars, + func.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + if (outputIsPickled) { + outputIterator.map(bytes => InternalRow(bytes)) + } else { + val unpickle = new Unpickler + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map(result => EvaluatePython.fromJava(result, outputSchema).asInstanceOf[InternalRow]) + } + } + } +} + /** * Applies the given function to each input row, appending the encoded result at the end of the row. 
*/ From 15fd836ddbb23a31c5f4860128e659f2cd557f1a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 15 Feb 2016 14:28:10 +0800 Subject: [PATCH 02/17] code cleanup --- python/pyspark/sql/dataframe.py | 72 ++++++++++++++++++--------------- python/pyspark/sql/tests.py | 21 ++++++++-- 2 files changed, 56 insertions(+), 37 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d979281732a3..d9be39a507c3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -284,12 +284,6 @@ def mapPartitions2(self, func): """ TODO """ return PipelinedDataFrame(self, func) - @ignore_unicode_prefix - @since(2.0) - def applySchema(self, schema): - """ TODO """ - return PipelinedDataFrame(self, lambda iterator: map(schema.toInternal, iterator), schema) - @ignore_unicode_prefix @since(1.3) def flatMap(self, f): @@ -1370,50 +1364,62 @@ class PipelinedDataFrame(DataFrame): """ TODO """ - schemaOfPickled = StructType().add("binary", BinaryType(), False, {"pickled": True}) - - def __init__(self, prev, func, schema=None): - if schema is None: - self._schema = self.schemaOfPickled - else: - self._schema = schema - - self.func = func - self._prev_jdf = prev._jdf # maintain the pipeline + def __init__(self, prev, func, output_schema=None): + self.output_schema = output_schema + self._schema = None self.is_cached = False self.sql_ctx = prev.sql_ctx self._sc = self.sql_ctx and self.sql_ctx._sc self._jdf_val = None self._lazy_rdd = None - if not isinstance(prev, PipelinedDataFrame) or not prev.is_cached: + if output_schema is not None: + # This transformation is applying schema, just copy member variables from prev. + self.func = func + self._prev_jdf = prev._prev_jdf + elif not isinstance(prev, PipelinedDataFrame) or not prev.is_cached: # This transformation is the first in its stage: - self.prev_func = None - elif prev._schema is not self.schemaOfPickled and self._schema is not self.schemaOfPickled: - # The previous operation is also adding schema, override it. 
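# A hedged usage sketch of the API this commit cleans up, assuming a live
# SQLContext bound to `sqlCtx` (as in the accompanying tests); the method
# names follow this diff and may still change in later commits of the series.
from pyspark.sql.types import StructType, IntegerType, StringType

df = sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))

# Typed transformation: Rows in, arbitrary Python objects (dicts here) out.
ds = df.mapPartitions2(
    lambda it: ({"key": r.key + 1, "value": r.value} for r in it))

# Attach a schema so the result behaves like a normal DataFrame again.
schema = StructType().add("key", IntegerType()).add("value", StringType())
print(ds.applySchema(schema).select("key").collect())  # keys shifted by one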
- self.prev_func = prev.prev_func + self.func = func + self._prev_jdf = prev._jdf else: - self.prev_func = _pipeline_func(prev.prev_func, prev.func) + self.func = _pipeline_func(prev.func, func) + self._prev_jdf = prev._prev_jdf # maintain the pipeline + + def applySchema(self, schema): + return PipelinedDataFrame(self, self.func, schema) @property def _jdf(self): if self._jdf_val is None: - final_func = _pipeline_func(self.prev_func, self.func) - self._jdf_val = self._prev_jdf.pythonMapPartitions( - _wrap_function(final_func, self._sc), self._schema.json()) + if self.output_schema is None: + schema = StructType().add("binary", BinaryType(), False, {"pickled": True}) + final_func = self.func + elif isinstance(self.output_schema, StructType): + schema = self.output_schema + to_row = lambda iterator: map(schema.toInternal, iterator) + final_func = _pipeline_func(self.func, to_row) + else: + data_type = self.output_schema + schema = StructType().add("value", data_type) + converter = lambda obj: (data_type.toInternal(obj), ) + to_row = lambda iterator: map(converter, iterator) + final_func = _pipeline_func(self.func, to_row) + + self._jdf_val = self._prev_jdf.pythonMapPartitions(self._wrap_function(final_func), schema.json()) return self._jdf_val -def _wrap_function(f, sc): - from pyspark.rdd import _prepare_for_python_RDD - from pyspark.serializers import AutoBatchedSerializer + def _wrap_function(self, f): + from pyspark.rdd import _prepare_for_python_RDD + from pyspark.serializers import AutoBatchedSerializer - ser = AutoBatchedSerializer(PickleSerializer()) - command = (lambda _, iterator: f(iterator), None, ser, ser) - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) - return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, - sc.pythonVer, broadcast_vars, sc._javaAccumulator) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (lambda _, iterator: f(iterator), None, ser, ser) + sc = self._sc + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) def _pipeline_func(prev_func, next_func): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 518e24b9d3f0..10bb0a3929e5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1155,14 +1155,27 @@ def test_functions_broadcast(self): def test_dataset(self): ds = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) - process = lambda row: {"key": 33, "value": "abc"} - ds2 = ds.mapPartitions2(lambda iterator: map(process, iterator)) + func = lambda row: {"key": row.key + 1, "value": row.value} # convert row to python dict + ds2 = ds.mapPartitions2(lambda iterator: map(func, iterator)) schema = StructType().add("key", IntegerType()).add("value", StringType()) ds3 = ds2.applySchema(schema) result = ds3.select("key").collect() - self.assertEqual(result[0][0], 33) - self.assertEqual(result[1][0], 33) + self.assertEqual(result[0][0], 2) + self.assertEqual(result[1][0], 3) + + schema = StructType().add("value", StringType()) # use a different but compatible schema + ds3 = ds2.applySchema(schema) + result = ds3.collect() + self.assertEqual(result[0][0], "1") + self.assertEqual(result[1][0], "2") + + func = lambda row: row.key * 3 + ds2 = ds.mapPartitions2(lambda iterator: map(func, iterator)) + ds3 = ds2.applySchema(IntegerType()) # use a flat 
schema + result = ds3.collect() + self.assertEqual(result[0][0], 3) + self.assertEqual(result[1][0], 6) class HiveContextSQLTests(ReusedPySparkTestCase): From a0a0dd63cf22048625622cc4dc5c425ee098cadb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 15 Feb 2016 14:57:44 +0800 Subject: [PATCH 03/17] scala side cleanup --- .../spark/sql/catalyst/plans/logical/object.scala | 9 --------- .../scala/org/apache/spark/sql/DataFrame.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 4 ++-- .../org/apache/spark/sql/execution/objects.scala | 14 +++++--------- 4 files changed, 8 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index f2c4033988af..791cc9367d18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -92,19 +92,10 @@ case class MapPartitions( override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) } -object PythonMapPartitions { - def apply(func: PythonFunction, schema: StructType, child: LogicalPlan): PythonMapPartitions = { - PythonMapPartitions(func, schema, schema.toAttributes, child) - } -} - case class PythonMapPartitions( func: PythonFunction, - outputSchema: StructType, output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - override lazy val schema: StructType = outputSchema - override def expressions: Seq[Expression] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 225c8cb47d37..b4676067eb6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1765,7 +1765,7 @@ class DataFrame private[sql]( func: PythonFunction, schemaJson: String): DataFrame = withPlan { val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType] - PythonMapPartitions(func, schema, logicalPlan) + PythonMapPartitions(func, schema.toAttributes, logicalPlan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 01ca1b520f98..9dc5550da452 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -371,8 +371,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil case e @ python.EvaluatePython(udf, child, _) => python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil - case logical.PythonMapPartitions(func, schema, output, child) => - execution.PythonMapPartitions(func, schema, output, planLater(child)) :: Nil + case logical.PythonMapPartitions(func, output, child) => + execution.PythonMapPartitions(func, output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 8495e3f7aac4..433bc40c5b1f 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -76,13 +76,10 @@ case class MapPartitions( case class PythonMapPartitions( func: PythonFunction, - outputSchema: StructType, output: Seq[Attribute], child: SparkPlan) extends UnaryNode { - override def producedAttributes: AttributeSet = outputSet - - override lazy val schema: StructType = outputSchema + override def expressions: Seq[Expression] = Nil private def isPickled(schema: StructType): Boolean = { schema.length == 1 && schema.head.dataType == BinaryType && @@ -93,9 +90,8 @@ case class PythonMapPartitions( val inputRDD = child.execute().map(_.copy()) val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - val childSchema = child.schema - val childIsPickled = isPickled(childSchema) - val outputIsPickled = isPickled(outputSchema) + val childIsPickled = isPickled(child.schema) + val outputIsPickled = isPickled(schema) inputRDD.mapPartitions { iter => val inputIterator = if (childIsPickled) { @@ -109,7 +105,7 @@ case class PythonMapPartitions( // For each row, add it to the queue. iter.grouped(100).map { inputRows => val toBePickled = inputRows.map { row => - EvaluatePython.toJava(row, childSchema) + EvaluatePython.toJava(row, child.schema) }.toArray pickle.dumps(toBePickled) } @@ -138,7 +134,7 @@ case class PythonMapPartitions( outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - }.map(result => EvaluatePython.fromJava(result, outputSchema).asInstanceOf[InternalRow]) + }.map(result => EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) } } } From 6c26daa92b32a57dcae028415cc30c6968fce9ee Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 15 Feb 2016 15:02:33 +0800 Subject: [PATCH 04/17] fix style --- python/pyspark/sql/dataframe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d9be39a507c3..27dce00884ee 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1405,11 +1405,11 @@ def _jdf(self): to_row = lambda iterator: map(converter, iterator) final_func = _pipeline_func(self.func, to_row) - self._jdf_val = self._prev_jdf.pythonMapPartitions(self._wrap_function(final_func), schema.json()) + wrapped_func = self._wrap_function(final_func) + self._jdf_val = self._prev_jdf.pythonMapPartitions(wrapped_func, schema.json()) return self._jdf_val - def _wrap_function(self, f): from pyspark.rdd import _prepare_for_python_RDD from pyspark.serializers import AutoBatchedSerializer From d96f103d2fc9fc36378b00bbfe5541fd95db3c2f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 15 Feb 2016 16:31:11 +0800 Subject: [PATCH 05/17] produce unsafe rows --- .../org/apache/spark/sql/execution/objects.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 433bc40c5b1f..f6d420908fbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -127,14 +127,22 @@ case class PythonMapPartitions( reuseWorker ).compute(inputIterator, 
context.partitionId(), context) + val resultProj = UnsafeProjection.create(output, output) + if (outputIsPickled) { - outputIterator.map(bytes => InternalRow(bytes)) + val row = new GenericMutableRow(1) + outputIterator.map { bytes => + row(0) = bytes + resultProj(row) + } } else { val unpickle = new Unpickler outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - }.map(result => EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) + }.map { result => + resultProj(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) + } } } } From 4dfe604bcf707500b0765362e3c3ce55d5e67412 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 15 Feb 2016 21:28:08 +0800 Subject: [PATCH 06/17] infer schema --- python/pyspark/sql/dataframe.py | 16 +++-- python/pyspark/sql/tests.py | 13 ++++- .../apache/spark/sql/execution/objects.scala | 58 ++++++------------- 3 files changed, 40 insertions(+), 47 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 27dce00884ee..9e09e56f72f6 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1374,7 +1374,7 @@ def __init__(self, prev, func, output_schema=None): self._lazy_rdd = None if output_schema is not None: - # This transformation is applying schema, just copy member variables from prev. + # This transformation is adding schema, just copy member variables from prev. self.func = func self._prev_jdf = prev._prev_jdf elif not isinstance(prev, PipelinedDataFrame) or not prev.is_cached: @@ -1385,16 +1385,22 @@ def __init__(self, prev, func, output_schema=None): self.func = _pipeline_func(prev.func, func) self._prev_jdf = prev._prev_jdf # maintain the pipeline - def applySchema(self, schema): + def schema(self, schema): return PipelinedDataFrame(self, self.func, schema) @property def _jdf(self): + from pyspark.sql.types import _infer_type, _merge_type + if self._jdf_val is None: if self.output_schema is None: - schema = StructType().add("binary", BinaryType(), False, {"pickled": True}) - final_func = self.func - elif isinstance(self.output_schema, StructType): + # If no schema is specified, infer it from the whole data set. + jrdd = self._prev_jdf.javaToPython() + rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer())) + func = self.func # assign to a local varible to avoid referencing self in closure. 
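# A small self-contained sketch of the inference step above, assuming only
# pyspark.sql.types (no running SparkContext): each sampled record is mapped
# to an inferred type and the per-record types are folded together, which is
# the same _infer_type / _merge_type combination the diff relies on.
from functools import reduce
from pyspark.sql import Row
from pyspark.sql.types import _infer_type, _merge_type

records = [Row(key=1, value="a"), Row(key=2, value=None)]
merged = reduce(_merge_type, (_infer_type(r) for r in records))
print(merged)  # StructType with key: long and value: string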
+ self.output_schema = rdd.mapPartitions(func).map(_infer_type).reduce(_merge_type) + + if isinstance(self.output_schema, StructType): schema = self.output_schema to_row = lambda iterator: map(schema.toInternal, iterator) final_func = _pipeline_func(self.func, to_row) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 10bb0a3929e5..b1b4fd6a7fee 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1159,24 +1159,31 @@ def test_dataset(self): func = lambda row: {"key": row.key + 1, "value": row.value} # convert row to python dict ds2 = ds.mapPartitions2(lambda iterator: map(func, iterator)) schema = StructType().add("key", IntegerType()).add("value", StringType()) - ds3 = ds2.applySchema(schema) + ds3 = ds2.schema(schema) result = ds3.select("key").collect() self.assertEqual(result[0][0], 2) self.assertEqual(result[1][0], 3) schema = StructType().add("value", StringType()) # use a different but compatible schema - ds3 = ds2.applySchema(schema) + ds3 = ds2.schema(schema) result = ds3.collect() self.assertEqual(result[0][0], "1") self.assertEqual(result[1][0], "2") func = lambda row: row.key * 3 ds2 = ds.mapPartitions2(lambda iterator: map(func, iterator)) - ds3 = ds2.applySchema(IntegerType()) # use a flat schema + ds3 = ds2.schema(IntegerType()) # use a flat schema result = ds3.collect() self.assertEqual(result[0][0], 3) self.assertEqual(result[1][0], 6) + result = ds2.collect() # schema can be inferred automatically + self.assertEqual(result[0][0], 3) + self.assertEqual(result[1][0], 6) + + # row count should be corrected even no schema is specified. + self.assertEqual(ds2.count(), 2) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index f6d420908fbe..049bf9a582c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.types.{BinaryType, ObjectType, StructType} +import org.apache.spark.sql.types.ObjectType /** * Helper functions for physical operators that work with user defined objects. @@ -81,34 +81,22 @@ case class PythonMapPartitions( override def expressions: Seq[Expression] = Nil - private def isPickled(schema: StructType): Boolean = { - schema.length == 1 && schema.head.dataType == BinaryType && - schema.head.metadata.contains("pickled") - } - override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute().map(_.copy()) val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - val childIsPickled = isPickled(child.schema) - val outputIsPickled = isPickled(schema) inputRDD.mapPartitions { iter => - val inputIterator = if (childIsPickled) { - iter.map(_.getBinary(0)) - } else { - EvaluatePython.registerPicklers() // register pickler for Row - - val pickle = new Pickler - - // Input iterator to Python: input rows are grouped so we send them in batches to Python. - // For each row, add it to the queue. 
- iter.grouped(100).map { inputRows => - val toBePickled = inputRows.map { row => - EvaluatePython.toJava(row, child.schema) - }.toArray - pickle.dumps(toBePickled) - } + EvaluatePython.registerPicklers() // register pickler for Row + val pickle = new Pickler + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + val inputIterator = iter.grouped(100).map { inputRows => + val toBePickled = inputRows.map { row => + EvaluatePython.toJava(row, child.schema) + }.toArray + pickle.dumps(toBePickled) } val context = TaskContext.get() @@ -127,22 +115,14 @@ case class PythonMapPartitions( reuseWorker ).compute(inputIterator, context.partitionId(), context) - val resultProj = UnsafeProjection.create(output, output) - - if (outputIsPickled) { - val row = new GenericMutableRow(1) - outputIterator.map { bytes => - row(0) = bytes - resultProj(row) - } - } else { - val unpickle = new Unpickler - outputIterator.flatMap { pickedResult => - val unpickledBatch = unpickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - }.map { result => - resultProj(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) - } + val unpickle = new Unpickler + val toUnsafe = UnsafeProjection.create(output, output) + + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + toUnsafe(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) } } } From e0ca98f8725a6b808cbabfb80c7a9a18f13150fe Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 17 Feb 2016 23:38:05 +0800 Subject: [PATCH 07/17] aggregate --- python/pyspark/rdd.py | 11 + python/pyspark/sql/dataframe.py | 131 +++++++---- python/pyspark/sql/group.py | 49 +++- python/pyspark/sql/tests.py | 56 +++-- .../sql/catalyst/plans/logical/object.scala | 22 +- .../org/apache/spark/sql/DataFrame.scala | 36 +++ .../spark/sql/GroupedPythonDataset.scala | 91 ++++++++ .../spark/sql/execution/SparkStrategies.scala | 4 + .../apache/spark/sql/execution/objects.scala | 215 ++++++++++++++++-- .../sql/execution/python/EvaluatePython.scala | 11 + 10 files changed, 537 insertions(+), 89 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index fe2264a63cf3..491abac21ebe 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2330,6 +2330,17 @@ def _prepare_for_python_RDD(sc, command, obj=None): return pickled_command, broadcast_vars, env, includes +def _wrap_function(sc, func, deserializer=None, serializer=None, profiler=None): + if deserializer is None: + deserializer = AutoBatchedSerializer(PickleSerializer()) + if serializer is None: + serializer = AutoBatchedSerializer(PickleSerializer()) + command = (func, profiler, deserializer, serializer) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + class PipelinedRDD(RDD): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 9e09e56f72f6..0863c3ea9964 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -278,12 +278,30 @@ def map(self, f): """ return self.rdd.map(f) + @ignore_unicode_prefix + @since(2.0) + def 
applySchema(self, schema=None): + """ TODO """ + if schema is None: + from pyspark.sql.types import _infer_type, _merge_type + # If no schema is specified, infer it from the whole data set. + jrdd = self._prev_jdf.javaToPython() + rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer())) + schema = rdd.mapPartitions(self._func).map(_infer_type).reduce(_merge_type) + return PipelinedDataFrame(self, output_schema=schema) + @ignore_unicode_prefix @since(2.0) def mapPartitions2(self, func): """ TODO """ return PipelinedDataFrame(self, func) + @ignore_unicode_prefix + @since(2.0) + def map2(self, func): + """ TODO """ + return self.mapPartitions2(lambda iterator: map(func, iterator)) + @ignore_unicode_prefix @since(1.3) def flatMap(self, f): @@ -896,10 +914,20 @@ def groupBy(self, *cols): >>> sorted(df.groupBy(['name', df.age]).count().collect()) [Row(name=u'Alice', age=2, count=1), Row(name=u'Bob', age=5, count=1)] """ - jgd = self._jdf.groupBy(self._jcols(*cols)) + jgd = self._jdf.pythonGroupBy(self._jcols(*cols)) from pyspark.sql.group import GroupedData return GroupedData(jgd, self.sql_ctx) + @ignore_unicode_prefix + @since(2.0) + def groupByKey(self, key_func, key_type): + """ TODO """ + f = lambda iterator: map(key_func, iterator) + wraped_func = _wrap_func(self._sc, self._jdf, f, False) + jgd = self._jdf.pythonGroupBy(wraped_func, key_type.json()) + from pyspark.sql.group import GroupedData + return GroupedData(jgd, self.sql_ctx, key_func) + @since(1.4) def rollup(self, *cols): """ @@ -1364,68 +1392,87 @@ class PipelinedDataFrame(DataFrame): """ TODO """ - def __init__(self, prev, func, output_schema=None): - self.output_schema = output_schema - self._schema = None + def __init__(self, prev, func=None, output_schema=None): + from pyspark.sql.group import GroupedData + + if output_schema is None: + self._schema = StructType().add("binary", BinaryType(), False, {"pickled": True}) + else: + self._schema = output_schema + + self._output_schema = output_schema + self._jdf_val = None self.is_cached = False self.sql_ctx = prev.sql_ctx self._sc = self.sql_ctx and self.sql_ctx._sc - self._jdf_val = None self._lazy_rdd = None - if output_schema is not None: - # This transformation is adding schema, just copy member variables from prev. - self.func = func - self._prev_jdf = prev._prev_jdf - elif not isinstance(prev, PipelinedDataFrame) or not prev.is_cached: + if isinstance(prev, GroupedData): + # prev is GroupedData, set the grouped flag to true and use jgd as jdf. + self._grouped = True + self._func = func + self._prev_jdf = prev._jgd + elif not isinstance(prev, PipelinedDataFrame) or prev.is_cached: # This transformation is the first in its stage: - self.func = func + self._func = func self._prev_jdf = prev._jdf + self._grouped = False else: - self.func = _pipeline_func(prev.func, func) - self._prev_jdf = prev._prev_jdf # maintain the pipeline - - def schema(self, schema): - return PipelinedDataFrame(self, self.func, schema) + if func is None: + # This transformation is applying schema, no need to pipeline the function. + self._func = prev._func + else: + self._func = _pipeline_func(prev._func, func) + # maintain the pipeline. + self._prev_jdf = prev._prev_jdf + self._grouped = prev._grouped @property def _jdf(self): - from pyspark.sql.types import _infer_type, _merge_type - if self._jdf_val is None: - if self.output_schema is None: - # If no schema is specified, infer it from the whole data set. 
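# A hedged sketch of the grouped API added in this patch, assuming a live
# SQLContext bound to `sqlCtx`; it mirrors the snippet used in group.py's
# _test() below, with the method names taken straight from this diff.
from pyspark.sql.types import IntegerType, StringType

ds = sqlCtx.createDataFrame([(i, i) for i in range(100)], ("key", "value"))

# Group rows with a Python key function, then collapse each group.
grouped = ds.groupByKey(lambda row: row.key % 5, IntegerType())
agged = grouped.mapGroups(
    lambda key, rows: "%d:%d" % (key, sum(r.value for r in rows)))

print(agged.applySchema(StringType()).collect())  # one "key:sum" string per group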
- jrdd = self._prev_jdf.javaToPython() - rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer())) - func = self.func # assign to a local varible to avoid referencing self in closure. - self.output_schema = rdd.mapPartitions(func).map(_infer_type).reduce(_merge_type) - - if isinstance(self.output_schema, StructType): - schema = self.output_schema + if self._output_schema is None: + self._jdf_val = self._create_jdf(self._func) + elif isinstance(self._output_schema, StructType): + schema = self._output_schema to_row = lambda iterator: map(schema.toInternal, iterator) - final_func = _pipeline_func(self.func, to_row) + self._jdf_val = self._create_jdf(_pipeline_func(self._func, to_row), schema) else: - data_type = self.output_schema + data_type = self._output_schema schema = StructType().add("value", data_type) converter = lambda obj: (data_type.toInternal(obj), ) to_row = lambda iterator: map(converter, iterator) - final_func = _pipeline_func(self.func, to_row) - - wrapped_func = self._wrap_function(final_func) - self._jdf_val = self._prev_jdf.pythonMapPartitions(wrapped_func, schema.json()) + self._jdf_val = self._create_jdf(_pipeline_func(self._func, to_row), schema) return self._jdf_val - def _wrap_function(self, f): - from pyspark.rdd import _prepare_for_python_RDD - from pyspark.serializers import AutoBatchedSerializer + def _create_jdf(self, func, schema=None): + wrapped_func = _wrap_func(self._sc, self._prev_jdf, func, schema is None) + if schema is None: + if self._grouped: + return self._prev_jdf.flatMapGroups(wrapped_func) + else: + return self._prev_jdf.pythonMapPartitions(wrapped_func) + else: + schema_string = schema.json() + if self._grouped: + return self._prev_jdf.flatMapGroups(wrapped_func, schema_string) + else: + return self._prev_jdf.pythonMapPartitions(wrapped_func, schema_string) + + +def _wrap_func(sc, jdf, func, output_binary): + if jdf.isPickled(): + deserializer = PickleSerializer() + else: + deserializer = None # the framework will provide a default one + + if output_binary: + serializer = PickleSerializer() + else: + serializer = None # the framework will provide a default one - ser = AutoBatchedSerializer(PickleSerializer()) - command = (lambda _, iterator: f(iterator), None, ser, ser) - sc = self._sc - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) - return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, - sc.pythonVer, broadcast_vars, sc._javaAccumulator) + from pyspark.rdd import _wrap_function + return _wrap_function(sc, lambda _, iterator: func(iterator), deserializer, serializer) def _pipeline_func(prev_func, next_func): diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ee734cb43928..759a4cd9cb4e 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -18,7 +18,7 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal -from pyspark.sql.dataframe import DataFrame +from pyspark.sql.dataframe import DataFrame, PipelinedDataFrame from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -27,7 +27,7 @@ def dfapi(f): def _api(self): name = f.__name__ - jdf = getattr(self._jdf, name)() + jdf = getattr(self._jgd, name)() return DataFrame(jdf, self.sql_ctx) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -37,7 +37,7 @@ def _api(self): def df_varargs_api(f): def _api(self, *args): name = f.__name__ - jdf = 
getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) + jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, args)) return DataFrame(jdf, self.sql_ctx) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -54,9 +54,33 @@ class GroupedData(object): .. versionadded:: 1.3 """ - def __init__(self, jdf, sql_ctx): - self._jdf = jdf + def __init__(self, jgd, sql_ctx, key_func=None): + self._jgd = jgd self.sql_ctx = sql_ctx + if key_func is None: + self.key_func = lambda key: key + else: + self.key_func = key_func + + @ignore_unicode_prefix + @since(2.0) + def flatMapGroups(self, func): + """ TODO """ + import itertools + key_func = self.key_func + + def process(iterator): + first = iterator.next() + key = key_func(first) + return func(key, itertools.chain([first], iterator)) + + return PipelinedDataFrame(self, process) + + @ignore_unicode_prefix + @since(2.0) + def mapGroups(self, func): + """ TODO """ + return self.flatMapGroups(lambda key, values: iter([func(key, values)])) @ignore_unicode_prefix @since(1.3) @@ -83,11 +107,11 @@ def agg(self, *exprs): """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): - jdf = self._jdf.agg(exprs[0]) + jdf = self._jgd.agg(exprs[0]) else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jdf = self._jdf.agg(exprs[0]._jc, + jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) return DataFrame(jdf, self.sql_ctx) @@ -187,9 +211,9 @@ def pivot(self, pivot_col, values=None): [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ if values is None: - jgd = self._jdf.pivot(pivot_col) + jgd = self._jgd.pivot(pivot_col) else: - jgd = self._jdf.pivot(pivot_col, values) + jgd = self._jgd.pivot(pivot_col, values) return GroupedData(jgd, self.sql_ctx) @@ -213,6 +237,13 @@ def _test(): Row(course="dotNET", year=2013, earnings=48000), Row(course="Java", year=2013, earnings=30000)]).toDF() + ds = globs['sqlContext'].createDataFrame([(i, i) for i in range(100)], ("key", "value")) + grouped = ds.groupByKey(lambda row: row.key % 5, IntegerType()) + value_sum = lambda rows: sum(map(lambda row: row.value, rows)) + agged = grouped.mapGroups(lambda key, values: str(key) + ":" + str(value_sum(values))) + result = agged.applySchema(StringType()).collect() + raise ValueError(result[0][0]) + (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b1b4fd6a7fee..719cb79f6b61 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1154,35 +1154,49 @@ def test_functions_broadcast(self): broadcast(df1)._jdf.queryExecution().executedPlan() def test_dataset(self): - ds = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + ds = self.sqlCtx.createDataFrame([(i, str(i)) for i in range(100)], ("key", "value")) - func = lambda row: {"key": row.key + 1, "value": row.value} # convert row to python dict - ds2 = ds.mapPartitions2(lambda iterator: map(func, iterator)) + # convert row to python dict + ds2 = ds.map2(lambda row: {"key": row.key + 1, "value": row.value}) schema = StructType().add("key", IntegerType()).add("value", StringType()) - ds3 = ds2.schema(schema) + ds3 = ds2.applySchema(schema) result = ds3.select("key").collect() - self.assertEqual(result[0][0], 2) - self.assertEqual(result[1][0], 3) + # test first 2 
elements + self.assertEqual(result[0][0], 1) + self.assertEqual(result[1][0], 2) + + # use a different but compatible schema + schema = StructType().add("value", StringType()) + result = ds2.applySchema(schema).collect() + self.assertEqual(result[0][0], "0") + self.assertEqual(result[1][0], "1") - schema = StructType().add("value", StringType()) # use a different but compatible schema - ds3 = ds2.schema(schema) - result = ds3.collect() - self.assertEqual(result[0][0], "1") - self.assertEqual(result[1][0], "2") + # use a flat schema + ds2 = ds.map2(lambda row: row.key * 3) + result = ds2.applySchema(IntegerType()).collect() + self.assertEqual(result[0][0], 0) + self.assertEqual(result[1][0], 3) - func = lambda row: row.key * 3 - ds2 = ds.mapPartitions2(lambda iterator: map(func, iterator)) - ds3 = ds2.schema(IntegerType()) # use a flat schema - result = ds3.collect() - self.assertEqual(result[0][0], 3) - self.assertEqual(result[1][0], 6) + # schema can be inferred automatically + result = ds2.applySchema().collect() + self.assertEqual(result[0][0], 0) + self.assertEqual(result[1][0], 3) - result = ds2.collect() # schema can be inferred automatically - self.assertEqual(result[0][0], 3) - self.assertEqual(result[1][0], 6) + # If no schema is given, by default it's a single binary field struct type. + from pyspark.sql.functions import length + result = ds2.select(length("value")).collect() + self.assertTrue(result[0][0] > 0) + self.assertTrue(result[1][0] > 0) # row count should be corrected even no schema is specified. - self.assertEqual(ds2.count(), 2) + self.assertEqual(ds2.count(), 100) + + # typed operation still works on cached Dataset. + ds3 = ds2.cache().map2(lambda key: key / 3) + self.assertEqual(ds3.count(), 100) + result = ds3.applySchema(IntegerType()).collect() + self.assertEqual(result[0][0], 0) + self.assertEqual(result[1][0], 3) class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 791cc9367d18..83cc20f5640d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -21,7 +21,7 @@ import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{ObjectType, StructType} +import org.apache.spark.sql.types.{DataType, ObjectType, StructType} /** * A trait for logical operators that apply user defined functions to domain objects. @@ -132,6 +132,17 @@ case class AppendColumns( override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) } +case class PythonAppendColumns( + func: PythonFunction, + newColumns: Seq[Attribute], + isFlat: Boolean, + child: LogicalPlan) extends UnaryNode { + + override def output: Seq[Attribute] = child.output ++ newColumns + + override def expressions: Seq[Expression] = Nil +} + /** Factory for constructing new `MapGroups` nodes. 
*/ object MapGroups { def apply[K : Encoder, T : Encoder, U : Encoder]( @@ -172,6 +183,15 @@ case class MapGroups( Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes) } +case class PythonMapGroups( + func: PythonFunction, + groupingExprs: Seq[Expression], + dataAttributes: Seq[Attribute], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def expressions: Seq[Expression] = groupingExprs +} + /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { def apply[Key : Encoder, Left : Encoder, Right : Encoder, Result : Encoder]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index b4676067eb6c..a8bc49d30001 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1761,6 +1761,10 @@ class DataFrame private[sql]( } } + protected[sql] def pythonMapPartitions(func: PythonFunction): DataFrame = withPlan { + PythonMapPartitions(func, EvaluatePython.schemaOfPickled.toAttributes, logicalPlan) + } + protected[sql] def pythonMapPartitions( func: PythonFunction, schemaJson: String): DataFrame = withPlan { @@ -1768,6 +1772,38 @@ class DataFrame private[sql]( PythonMapPartitions(func, schema.toAttributes, logicalPlan) } + protected[sql] def pythonGroupBy( + func: PythonFunction, + keyTypeJson: String): GroupedPythonDataset = { + val keyType = DataType.fromJson(keyTypeJson) + val isFlat = !keyType.isInstanceOf[StructType] + val keyAttributes = if (isFlat) { + Seq(AttributeReference("key", keyType)()) + } else { + keyType.asInstanceOf[StructType].toAttributes + } + + val inputPlan = queryExecution.analyzed + val withGroupingKey = PythonAppendColumns(func, keyAttributes, isFlat, inputPlan) + val executed = sqlContext.executePlan(withGroupingKey) + + new GroupedPythonDataset( + executed, + withGroupingKey.newColumns, + inputPlan.output, + GroupedData.GroupByType) + } + + protected[sql] def pythonGroupBy(cols: Column*): GroupedPythonDataset = { + new GroupedPythonDataset( + queryExecution, + cols.map(_.expr), + queryExecution.analyzed.output, + GroupedData.GroupByType) + } + + protected[sql] def isPickled(): Boolean = EvaluatePython.isPickled(schema) + /** * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with * an execution. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala new file mode 100644 index 000000000000..aa28d2db1b4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.api.python.PythonFunction +import org.apache.spark.sql.GroupedData.GroupType +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.plans.logical.PythonMapGroups +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.types.{DataType, StructType} + +class GroupedPythonDataset private[sql]( + queryExecution: QueryExecution, + groupingExprs: Seq[Expression], + dataAttributes: Seq[Attribute], + groupType: GroupType) { + + private def sqlContext = queryExecution.sqlContext + + protected[sql] def isPickled(): Boolean = + EvaluatePython.isPickled(queryExecution.analyzed.output.toStructType) + + private def groupedData = + new GroupedData( + new DataFrame(sqlContext, queryExecution), groupingExprs, GroupedData.GroupByType) + + @scala.annotation.varargs + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { + groupedData.agg(aggExpr, aggExprs: _*) + } + + def agg(exprs: Map[String, String]): DataFrame = groupedData.agg(exprs) + + def agg(exprs: java.util.Map[String, String]): DataFrame = groupedData.agg(exprs) + + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame = groupedData.agg(expr, exprs: _*) + + def count(): DataFrame = groupedData.count() + + @scala.annotation.varargs + def mean(colNames: String*): DataFrame = groupedData.mean(colNames: _*) + + @scala.annotation.varargs + def max(colNames: String*): DataFrame = groupedData.max(colNames: _*) + + @scala.annotation.varargs + def avg(colNames: String*): DataFrame = groupedData.avg(colNames: _*) + + @scala.annotation.varargs + def min(colNames: String*): DataFrame = groupedData.min(colNames: _*) + + @scala.annotation.varargs + def sum(colNames: String*): DataFrame = groupedData.sum(colNames: _*) + + def flatMapGroups(f: PythonFunction, schemaJson: String): DataFrame = { + val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType] + internalFlatMapGroups(f, schema) + } + + def flatMapGroups(f: PythonFunction): DataFrame = { + internalFlatMapGroups(f, EvaluatePython.schemaOfPickled) + } + + private def internalFlatMapGroups(f: PythonFunction, schema: StructType): DataFrame = { + new DataFrame( + sqlContext, + PythonMapGroups( + f, + groupingExprs, + dataAttributes, + schema.toAttributes, + queryExecution.logical)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9dc5550da452..470dcaa0f6a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -373,6 +373,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case logical.PythonMapPartitions(func, output, child) => execution.PythonMapPartitions(func, output, planLater(child)) :: Nil + case logical.PythonAppendColumns(func, newColumns, isFlat, child) => + execution.PythonAppendColumns(func, newColumns, isFlat, planLater(child)) :: Nil + case logical.PythonMapGroups(func, grouping, data, output, child) => + execution.PythonMapGroups(func, grouping, data, output, planLater(child)) :: 
Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 049bf9a582c3..2814c95adc33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -85,18 +85,23 @@ case class PythonMapPartitions( val inputRDD = child.execute().map(_.copy()) val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val isChildPickled = EvaluatePython.isPickled(child.schema) + val isOutputPickled = EvaluatePython.isPickled(schema) inputRDD.mapPartitions { iter => - EvaluatePython.registerPicklers() // register pickler for Row - val pickle = new Pickler - - // Input iterator to Python: input rows are grouped so we send them in batches to Python. - // For each row, add it to the queue. - val inputIterator = iter.grouped(100).map { inputRows => - val toBePickled = inputRows.map { row => - EvaluatePython.toJava(row, child.schema) - }.toArray - pickle.dumps(toBePickled) + val inputIterator = if (isChildPickled) { + iter.map(_.getBinary(0)) + } else { + EvaluatePython.registerPicklers() // register pickler for Row + val pickle = new Pickler + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + iter.grouped(100).map { inputRows => + val toBePickled = inputRows.map { row => + EvaluatePython.toJava(row, child.schema) + }.toArray + pickle.dumps(toBePickled) + } } val context = TaskContext.get() @@ -115,14 +120,22 @@ case class PythonMapPartitions( reuseWorker ).compute(inputIterator, context.partitionId(), context) - val unpickle = new Unpickler val toUnsafe = UnsafeProjection.create(output, output) - outputIterator.flatMap { pickedResult => - val unpickledBatch = unpickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - }.map { result => - toUnsafe(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) + if (isOutputPickled) { + val row = new GenericMutableRow(1) + outputIterator.map { bytes => + row(0) = bytes + toUnsafe(row) + } + } else { + val unpickle = new Unpickler + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + toUnsafe(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) + } } } } @@ -157,6 +170,92 @@ case class AppendColumns( } } +case class PythonAppendColumns( + func: PythonFunction, + newColumns: Seq[Attribute], + isFlat: Boolean, + child: SparkPlan) extends UnaryNode { + + override def output: Seq[Attribute] = child.output ++ newColumns + + override def expressions: Seq[Expression] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val newColumnSchema = newColumns.toStructType + val isChildPickled = EvaluatePython.isPickled(child.schema) + + inputRDD.mapPartitionsInternal { iter => + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. 
+ val queue = new java.util.LinkedList[InternalRow]() + + val inputIterator = if (isChildPickled) { + iter.map { row => + queue.add(row) + row.getBinary(0) + } + } else { + EvaluatePython.registerPicklers() // register pickler for Row + val pickle = new Pickler + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + iter.grouped(100).map { inputRows => + val toBePickled = inputRows.map { row => + queue.add(row) + EvaluatePython.toJava(row, child.schema) + }.toArray + pickle.dumps(toBePickled) + } + } + + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = + new PythonRunner( + func.command, + func.envVars, + func.pythonIncludes, + func.pythonExec, + func.pythonVer, + func.broadcastVars, + func.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler + val toUnsafe = UnsafeProjection.create(newColumns, newColumns) + val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema) + + val newData = outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + } + + val newRows = if (isFlat) { + val row = new GenericMutableRow(1) + newData.map { key => + row(0) = EvaluatePython.fromJava(key, newColumns.head.dataType) + toUnsafe(row) + } + } else { + newData.map { key => + toUnsafe(EvaluatePython.fromJava(key, newColumnSchema).asInstanceOf[InternalRow]) + } + } + + newRows.map { newRow => + combiner.join(queue.poll().asInstanceOf[UnsafeRow], newRow) + } + } + } +} + /** * Groups the input rows together and calls the function with each group and an iterator containing * all elements in the group. The result of this function is encoded and flattened before @@ -197,6 +296,90 @@ case class MapGroups( } } +case class PythonMapGroups( + func: PythonFunction, + groupingExprs: Seq[Expression], + dataAttributes: Seq[Attribute], + output: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override def expressions: Seq[Expression] = groupingExprs + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingExprs) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingExprs.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val isValuePickled = EvaluatePython.isPickled(dataAttributes.toStructType) + val isOutputPickled = EvaluatePython.isPickled(schema) + + inputRDD.mapPartitionsInternal { iter => + EvaluatePython.registerPicklers() // register pickler for Row + val pickle = new Pickler + val unpickle = new Unpickler + val grouped = GroupedIterator(iter, groupingExprs, child.output) + + grouped.flatMap { case (_, values) => + val inputIterator = if (isValuePickled) { + iter.map(_.getBinary(0)) + } else { + val getValue: InternalRow => InternalRow = if (dataAttributes == child.output) { + identity + } else { + UnsafeProjection.create(dataAttributes, child.output) + } + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. 
+ iter.grouped(100).map { inputRows => + val toBePickled = inputRows.map { row => + EvaluatePython.toJava(getValue(row), child.schema) + }.toArray + pickle.dumps(toBePickled) + } + } + + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = + new PythonRunner( + func.command, + func.envVars, + func.pythonIncludes, + func.pythonExec, + func.pythonVer, + func.broadcastVars, + func.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val toUnsafe = UnsafeProjection.create(output, output) + + if (isOutputPickled) { + val row = new GenericMutableRow(1) + outputIterator.map { bytes => + row(0) = bytes + toUnsafe(row) + } + } else { + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + toUnsafe(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) + } + } + } + } + } +} + /** * Co-groups the data from left and right children, and calls the function with each group and 2 * iterators containing all elements in the group from left and right side. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 8c46516594a2..78aa0cc33fd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -248,6 +248,17 @@ object EvaluatePython { } } + val schemaOfPickled = { + val metaPickled = new MetadataBuilder().putBoolean("pickled", true).build() + new StructType().add("value", BinaryType, nullable = false, metadata = metaPickled) + } + + def isPickled(schema: StructType): Boolean = schema.length == 1 && { + val field = schema.head + field.dataType == BinaryType && + field.metadata.contains("pickled") && field.metadata.getBoolean("pickled") + } + /** * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by * PySpark. 
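The "pickled" convention added above is the backbone of this series: a Dataset whose rows are arbitrary Python objects is represented on the JVM side as a single non-nullable binary column whose metadata carries pickled = true, and EvaluatePython.isPickled is how every operator decides whether to pass the bytes through untouched or to convert rows the usual way. A rough Python-side mirror of that convention, with illustrative names only, looks like this:

    from pyspark.sql.types import BinaryType, StructType

    # Single binary field tagged as holding pickled Python objects
    # (mirrors EvaluatePython.schemaOfPickled in the hunk above).
    PICKLED_SCHEMA = StructType().add("value", BinaryType(), False, {"pickled": True})

    def is_pickled(schema):
        # Mirrors EvaluatePython.isPickled: exactly one binary field whose
        # metadata says "pickled" means the rows are opaque pickled objects.
        if len(schema.fields) != 1:
            return False
        field = schema.fields[0]
        return isinstance(field.dataType, BinaryType) and field.metadata.get("pickled", False)

    assert is_pickled(PICKLED_SCHEMA)
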
From da77adc296e9cbf48309ed76926ed2ed0354430c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 19 Feb 2016 11:31:49 +0800 Subject: [PATCH 08/17] improve aggregate --- python/pyspark/rdd.py | 8 +- python/pyspark/sql/dataframe.py | 42 ++++--- python/pyspark/sql/group.py | 116 +++++++++++++++--- python/pyspark/sql/tests.py | 39 ++++++ .../sql/catalyst/expressions/package.scala | 4 +- .../org/apache/spark/sql/DataFrame.scala | 10 +- .../spark/sql/GroupedPythonDataset.scala | 3 +- .../apache/spark/sql/execution/objects.scala | 101 ++++++++------- 8 files changed, 227 insertions(+), 96 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 491abac21ebe..8db38bcf0c7f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2330,11 +2330,9 @@ def _prepare_for_python_RDD(sc, command, obj=None): return pickled_command, broadcast_vars, env, includes -def _wrap_function(sc, func, deserializer=None, serializer=None, profiler=None): - if deserializer is None: - deserializer = AutoBatchedSerializer(PickleSerializer()) - if serializer is None: - serializer = AutoBatchedSerializer(PickleSerializer()) +def _wrap_function(sc, func, deserializer, serializer, profiler=None): + assert deserializer, "deserializer should not be empty" + assert serializer, "serializer should not be empty" command = (func, profiler, deserializer, serializer) pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0863c3ea9964..efa490d1bbb9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -28,7 +28,8 @@ from pyspark import since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, PickleSerializer, \ + UTF8Deserializer, PairDeserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -236,9 +237,14 @@ def collect(self): >>> df.collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ + + if self._jdf.isPickled(): + deserializer = PickleSerializer() + else: + deserializer = BatchedSerializer(PickleSerializer()) with SCCallSiteSync(self._sc) as css: port = self._jdf.collectToPython() - return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + return list(_load_from_socket(port, deserializer)) @ignore_unicode_prefix @since(1.3) @@ -282,13 +288,16 @@ def map(self, f): @since(2.0) def applySchema(self, schema=None): """ TODO """ - if schema is None: - from pyspark.sql.types import _infer_type, _merge_type - # If no schema is specified, infer it from the whole data set. - jrdd = self._prev_jdf.javaToPython() - rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer())) - schema = rdd.mapPartitions(self._func).map(_infer_type).reduce(_merge_type) - return PipelinedDataFrame(self, output_schema=schema) + if isinstance(self, PipelinedDataFrame): + if schema is None: + from pyspark.sql.types import _infer_type, _merge_type + # If no schema is specified, infer it from the whole data set. 
+ jrdd = self._prev_jdf.javaToPython() + rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer())) + schema = rdd.mapPartitions(self._func).map(_infer_type).reduce(_merge_type) + return PipelinedDataFrame(self, output_schema=schema) + else: + return self @ignore_unicode_prefix @since(2.0) @@ -926,7 +935,7 @@ def groupByKey(self, key_func, key_type): wraped_func = _wrap_func(self._sc, self._jdf, f, False) jgd = self._jdf.pythonGroupBy(wraped_func, key_type.json()) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx, key_func) + return GroupedData(jgd, self.sql_ctx, not isinstance(key_type, StructType)) @since(1.4) def rollup(self, *cols): @@ -1396,6 +1405,7 @@ def __init__(self, prev, func=None, output_schema=None): from pyspark.sql.group import GroupedData if output_schema is None: + # should get it from java side self._schema = StructType().add("binary", BinaryType(), False, {"pickled": True}) else: self._schema = output_schema @@ -1446,7 +1456,7 @@ def _jdf(self): return self._jdf_val def _create_jdf(self, func, schema=None): - wrapped_func = _wrap_func(self._sc, self._prev_jdf, func, schema is None) + wrapped_func = _wrap_func(self._sc, self._prev_jdf, func, schema is None, self._grouped) if schema is None: if self._grouped: return self._prev_jdf.flatMapGroups(wrapped_func) @@ -1460,16 +1470,18 @@ def _create_jdf(self, func, schema=None): return self._prev_jdf.pythonMapPartitions(wrapped_func, schema_string) -def _wrap_func(sc, jdf, func, output_binary): - if jdf.isPickled(): +def _wrap_func(sc, jdf, func, output_binary, input_grouped=False): + if input_grouped: + deserializer = PairDeserializer(PickleSerializer(), PickleSerializer()) + elif jdf.isPickled(): deserializer = PickleSerializer() else: - deserializer = None # the framework will provide a default one + deserializer = AutoBatchedSerializer(PickleSerializer()) if output_binary: serializer = PickleSerializer() else: - serializer = None # the framework will provide a default one + serializer = AutoBatchedSerializer(PickleSerializer()) from pyspark.rdd import _wrap_function return _wrap_function(sc, lambda _, iterator: func(iterator), deserializer, serializer) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 759a4cd9cb4e..a11fbb595ec8 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -15,6 +15,15 @@ # limitations under the License. # +import sys + +if sys.version >= '3': + basestring = unicode = str + long = int + from functools import reduce +else: + from itertools import imap as map + from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal @@ -54,25 +63,25 @@ class GroupedData(object): .. 
versionadded:: 1.3 """ - def __init__(self, jgd, sql_ctx, key_func=None): + def __init__(self, jgd, sql_ctx, flat_key=False): self._jgd = jgd self.sql_ctx = sql_ctx - if key_func is None: - self.key_func = lambda key: key + if flat_key: + self._key_converter = lambda key: key[0] else: - self.key_func = key_func + self._key_converter = lambda key: key @ignore_unicode_prefix @since(2.0) def flatMapGroups(self, func): """ TODO """ - import itertools - key_func = self.key_func + key_converter = self._key_converter - def process(iterator): - first = iterator.next() - key = key_func(first) - return func(key, itertools.chain([first], iterator)) + def process(inputs): + record_converter = lambda record: (key_converter(record[0]), record[1]) + for key, values in GroupedIterator(map(record_converter, inputs)): + for output in func(key, values): + yield output return PipelinedDataFrame(self, process) @@ -217,6 +226,86 @@ def pivot(self, pivot_col, values=None): return GroupedData(jgd, self.sql_ctx) +class GroupedIterator(object): + """ TODO """ + + def __init__(self, inputs): + self.inputs = BufferedIterator(inputs) + self.current_input = inputs.next() + self.current_key = self.current_input[0] + self.current_values = GroupValuesIterator(self) + + def __iter__(self): + return self + + def next(self): + if self.current_values is None: + self._fetch_next_group() + + ret = (self.current_key, self.current_values) + self.current_values = None + return ret + + def _fetch_next_group(self): + if self.current_input is None: + self.current_input = self.inputs.next() + + # Skip to next group, or consume all inputs and throw StopIteration exception. + while self.current_input[0] == self.current_key: + self.current_input = self.inputs.next() + + self.current_key = self.current_input[0] + self.current_values = GroupValuesIterator(self) + + +class GroupValuesIterator(object): + """ TODO """ + + def __init__(self, outter): + self.outter = outter + + def __iter__(self): + return self + + def next(self): + if self.outter.current_input is None: + self._fetch_next_value() + + value = self.outter.current_input[1] + self.outter.current_input = None + return value + + def _fetch_next_value(self): + if self.outter.inputs.head()[0] == self.outter.current_key: + self.outter.current_input = self.outter.inputs.next() + else: + raise StopIteration + + +class BufferedIterator(object): + """ TODO """ + + def __init__(self, iterator): + self.iterator = iterator + self.buffered = None + + def __iter__(self): + return self + + def next(self): + if self.buffered is None: + return self.iterator.next() + else: + item = self.buffered + self.buffered = None + return item + + def head(self): + if self.buffered is None: + self.buffered = self.iterator.next() + return self.buffered + + def _test(): import doctest from pyspark.context import SparkContext @@ -237,13 +326,6 @@ def _test(): Row(course="dotNET", year=2013, earnings=48000), Row(course="Java", year=2013, earnings=30000)]).toDF() - ds = globs['sqlContext'].createDataFrame([(i, i) for i in range(100)], ("key", "value")) - grouped = ds.groupByKey(lambda row: row.key % 5, IntegerType()) - value_sum = lambda rows: sum(map(lambda row: row.value, rows)) - agged = grouped.mapGroups(lambda key, values: str(key) + ":" + str(value_sum(values))) - result = agged.applySchema(StringType()).collect() - raise ValueError(result[0][0]) - (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) 
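The hand-rolled GroupedIterator above exists because the Python worker receives one long stream of (key, value) pairs, already clustered by key thanks to the required sort on the JVM side, and has to re-expose it as lazy per-group iterators without materializing a whole group in memory. Functionally it behaves much like itertools.groupby over a key-sorted stream; a minimal sketch of the same idea, using invented sample data:

    import itertools

    def iterate_groups(sorted_pairs):
        # sorted_pairs: (key, value) pairs already clustered by key, as the
        # JVM side guarantees after the sort/shuffle for grouped operations.
        for key, pairs in itertools.groupby(sorted_pairs, key=lambda kv: kv[0]):
            yield key, (value for _, value in pairs)

    pairs = iter([(0, "a"), (0, "b"), (1, "c")])
    print([(k, list(vs)) for k, vs in iterate_groups(pairs)])
    # [(0, ['a', 'b']), (1, ['c'])]

The custom classes in the patch trade itertools for explicit look-ahead (BufferedIterator.head), so skipping an unconsumed group stays under the patch's own control.
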
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 719cb79f6b61..2bdc9942e43b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1188,6 +1188,11 @@ def test_dataset(self): self.assertTrue(result[0][0] > 0) self.assertTrue(result[1][0] > 0) + # If no schema is given, collect will return custom objects instead of rows. + result = ds2.collect() + self.assertEqual(result[0], 0) + self.assertEqual(result[1], 3) + # row count should be corrected even no schema is specified. self.assertEqual(ds2.count(), 100) @@ -1198,6 +1203,40 @@ def test_dataset(self): self.assertEqual(result[0][0], 0) self.assertEqual(result[1][0], 3) + def test_typed_aggregate(self): + data = [(i, i * 2) for i in range(100)] + ds = self.sqlCtx.createDataFrame(data, ("key", "value")) + sum_tuple = lambda values: sum(map(lambda value: value[0] * value[1], values)) + + def get_python_result(data, key_func, agg_func): + data.sort(key=key_func) + expected_result = [] + import itertools + for key, values in itertools.groupby(data, key_func): + expected_result.append(agg_func(key, values)) + return expected_result + + grouped = ds.groupByKey(lambda row: row.key % 5, IntegerType()) + agg_func = lambda key, values: str(key) + ":" + str(sum_tuple(values)) + result = sorted(grouped.mapGroups(agg_func).collect()) + expected_result = get_python_result(data, lambda i: i[0] % 5, agg_func) + self.assertEqual(result, expected_result) + + # We can also call groupByKey on a Dataset of custom objects. + ds2 = ds.map2(lambda row: row.key) + grouped = ds2.groupByKey(lambda i: i % 5, IntegerType()) + agg_func = lambda key, values: str(key) + ":" + str(sum(values)) + result = sorted(grouped.mapGroups(agg_func).collect()) + expected_result = get_python_result(range(100), lambda i: i % 5, agg_func) + self.assertEqual(result, expected_result) + + # We can also apply typed aggregate after structured groupBy, the key is row object. + grouped = ds.groupBy(ds.key % 2, ds.key % 3) + agg_func = lambda key, values: str(key[0]) + str(key[1]) + ":" + str(sum_tuple(values)) + result = sorted(grouped.mapGroups(agg_func).collect()) + expected_result = get_python_result(data, lambda i: (i[0] % 2, i[0] % 3), agg_func) + self.assertEqual(result, expected_result) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index f1fa13daa77e..83389ae2ef7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -88,8 +88,6 @@ package object expressions { */ implicit class AttributeSeq(attrs: Seq[Attribute]) { /** Creates a StructType with a schema matching this `Seq[Attribute]`. */ - def toStructType: StructType = { - StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable))) - } + def toStructType: StructType = StructType.fromAttributes(attrs) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index a8bc49d30001..154a60b0bf97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1750,9 +1750,13 @@ class DataFrame private[sql]( * Converts a JavaRDD to a PythonRDD. 
*/ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val structType = schema // capture it for closure - val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) - EvaluatePython.javaToPython(rdd) + if (EvaluatePython.isPickled(schema)) { + queryExecution.toRdd.map(_.getBinary(0)) + } else { + val structType = schema // capture it for closure + val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) + EvaluatePython.javaToPython(rdd) + } } protected[sql] def collectToPython(): Int = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala index aa28d2db1b4a..010642410bd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala @@ -33,8 +33,7 @@ class GroupedPythonDataset private[sql]( private def sqlContext = queryExecution.sqlContext - protected[sql] def isPickled(): Boolean = - EvaluatePython.isPickled(queryExecution.analyzed.output.toStructType) + protected[sql] def isPickled(): Boolean = EvaluatePython.isPickled(dataAttributes.toStructType) private def groupedData = new GroupedData( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 2814c95adc33..d144115bb277 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.types.ObjectType +import org.apache.spark.sql.types.{StructField, StructType, ObjectType} /** * Helper functions for physical operators that work with user defined objects. 
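Patch 08 also tightens the driver-side round trip: collect() now picks its deserializer from isPickled(), because a pickled Dataset streams one pickle per custom object while an ordinary DataFrame streams pickled batches of rows, and javaToPython simply extracts the binary column instead of converting and re-pickling. A small stand-alone sketch of why the two streams need different decoding (sample data invented):

    import pickle

    # "Pickled" Dataset: one pickle per row, each holding an arbitrary object.
    per_object_stream = [pickle.dumps(obj) for obj in (10, 20, 30)]
    objects = [pickle.loads(b) for b in per_object_stream]            # [10, 20, 30]

    # Ordinary DataFrame: each pickle holds a whole batch of rows.
    batched_stream = [pickle.dumps([("Alice", 2), ("Bob", 5)])]
    rows = [r for batch in batched_stream for r in pickle.loads(batch)]

    assert objects == [10, 20, 30] and rows == [("Alice", 2), ("Bob", 5)]
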
@@ -315,65 +315,64 @@ case class PythonMapGroups( val inputRDD = child.execute().map(_.copy()) val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - val isValuePickled = EvaluatePython.isPickled(dataAttributes.toStructType) + + val keySchema = StructType(groupingExprs.map(_.dataType).map(dt => StructField("k", dt))) + val valueSchema = dataAttributes.toStructType + val isValuePickled = EvaluatePython.isPickled(valueSchema) val isOutputPickled = EvaluatePython.isPickled(schema) inputRDD.mapPartitionsInternal { iter => EvaluatePython.registerPicklers() // register pickler for Row val pickle = new Pickler - val unpickle = new Unpickler - val grouped = GroupedIterator(iter, groupingExprs, child.output) - grouped.flatMap { case (_, values) => - val inputIterator = if (isValuePickled) { - iter.map(_.getBinary(0)) + val getKey = UnsafeProjection.create(groupingExprs, child.output) + val getValue: InternalRow => InternalRow = if (dataAttributes == child.output) { + identity + } else { + UnsafeProjection.create(dataAttributes, child.output) + } + + val inputIterator = iter.map { input => + val keyBytes = pickle.dumps(EvaluatePython.toJava(getKey(input), keySchema)) + val valueBytes = if (isValuePickled) { + input.getBinary(0) } else { - val getValue: InternalRow => InternalRow = if (dataAttributes == child.output) { - identity - } else { - UnsafeProjection.create(dataAttributes, child.output) - } - - // Input iterator to Python: input rows are grouped so we send them in batches to Python. - iter.grouped(100).map { inputRows => - val toBePickled = inputRows.map { row => - EvaluatePython.toJava(getValue(row), child.schema) - }.toArray - pickle.dumps(toBePickled) - } + pickle.dumps(EvaluatePython.toJava(getValue(input), valueSchema)) } + keyBytes -> valueBytes + } - val context = TaskContext.get() - - // Output iterator for results from Python. - val outputIterator = - new PythonRunner( - func.command, - func.envVars, - func.pythonIncludes, - func.pythonExec, - func.pythonVer, - func.broadcastVars, - func.accumulator, - bufferSize, - reuseWorker - ).compute(inputIterator, context.partitionId(), context) - - val toUnsafe = UnsafeProjection.create(output, output) - - if (isOutputPickled) { - val row = new GenericMutableRow(1) - outputIterator.map { bytes => - row(0) = bytes - toUnsafe(row) - } - } else { - outputIterator.flatMap { pickedResult => - val unpickledBatch = unpickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - }.map { result => - toUnsafe(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) - } + val context = TaskContext.get() + + // Output iterator for results from Python. 
+ val outputIterator = + new PythonRunner( + func.command, + func.envVars, + func.pythonIncludes, + func.pythonExec, + func.pythonVer, + func.broadcastVars, + func.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val toUnsafe = UnsafeProjection.create(output, output) + + if (isOutputPickled) { + val row = new GenericMutableRow(1) + outputIterator.map { bytes => + row(0) = bytes + toUnsafe(row) + } + } else { + val unpickle = new Unpickler + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + toUnsafe(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) } } } From a77249208153304f9a4e520b2df730ab6c443e78 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 19 Feb 2016 11:36:39 +0800 Subject: [PATCH 09/17] fix style --- .../src/main/scala/org/apache/spark/sql/execution/objects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d144115bb277..509140ca2648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.types.{StructField, StructType, ObjectType} +import org.apache.spark.sql.types.{ObjectType, StructField, StructType} /** * Helper functions for physical operators that work with user defined objects. 
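With the grouped execution path in place, the Python-facing API exercised by the tests reads quite naturally. The following is a sketch adapted from test_typed_aggregate, assuming an active SQLContext bound to sqlCtx; mapGroups is used by the tests even though its definition is not part of the hunks shown here:

    from pyspark.sql.types import IntegerType

    ds = sqlCtx.createDataFrame([(i, i * 2) for i in range(10)], ("i", "j"))
    grouped = ds.groupByKey(lambda row: row.i % 2, IntegerType())

    # Aggregate each group in plain Python; with no schema applied, collect()
    # returns the custom objects (here, strings) directly.
    agged = grouped.mapGroups(lambda key, values: "%d:%d" % (key, sum(v[1] for v in values)))
    print(sorted(agged.collect()))    # ['0:40', '1:50']
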
From 590308ab6c93ae5787a59804767d9a5c1545b6ea Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 19 Feb 2016 14:21:24 +0800 Subject: [PATCH 10/17] add pivot --- .../scala/org/apache/spark/sql/GroupedPythonDataset.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala index 010642410bd2..adc47e14b5cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala @@ -68,6 +68,14 @@ class GroupedPythonDataset private[sql]( @scala.annotation.varargs def sum(colNames: String*): DataFrame = groupedData.sum(colNames: _*) + def pivot(pivotColumn: String): GroupedData = groupedData.pivot(pivotColumn) + + def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = + groupedData.pivot(pivotColumn, values) + + def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = + groupedData.pivot(pivotColumn, values) + def flatMapGroups(f: PythonFunction, schemaJson: String): DataFrame = { val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType] internalFlatMapGroups(f, schema) From c883fa6f4e7253f62fb156aec0b4a3b0e1850d72 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 19 Feb 2016 15:07:05 +0800 Subject: [PATCH 11/17] some more tests --- python/pyspark/sql/tests.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2bdc9942e43b..9a316efe9fb4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1205,7 +1205,7 @@ def test_dataset(self): def test_typed_aggregate(self): data = [(i, i * 2) for i in range(100)] - ds = self.sqlCtx.createDataFrame(data, ("key", "value")) + ds = self.sqlCtx.createDataFrame(data, ("i", "j")) sum_tuple = lambda values: sum(map(lambda value: value[0] * value[1], values)) def get_python_result(data, key_func, agg_func): @@ -1216,14 +1216,14 @@ def get_python_result(data, key_func, agg_func): expected_result.append(agg_func(key, values)) return expected_result - grouped = ds.groupByKey(lambda row: row.key % 5, IntegerType()) + grouped = ds.groupByKey(lambda row: row.i % 5, IntegerType()) agg_func = lambda key, values: str(key) + ":" + str(sum_tuple(values)) result = sorted(grouped.mapGroups(agg_func).collect()) expected_result = get_python_result(data, lambda i: i[0] % 5, agg_func) self.assertEqual(result, expected_result) # We can also call groupByKey on a Dataset of custom objects. - ds2 = ds.map2(lambda row: row.key) + ds2 = ds.map2(lambda row: row.i) grouped = ds2.groupByKey(lambda i: i % 5, IntegerType()) agg_func = lambda key, values: str(key) + ":" + str(sum(values)) result = sorted(grouped.mapGroups(agg_func).collect()) @@ -1231,12 +1231,20 @@ def get_python_result(data, key_func, agg_func): self.assertEqual(result, expected_result) # We can also apply typed aggregate after structured groupBy, the key is row object. 
- grouped = ds.groupBy(ds.key % 2, ds.key % 3) + grouped = ds.groupBy(ds.i % 2, ds.i % 3) agg_func = lambda key, values: str(key[0]) + str(key[1]) + ":" + str(sum_tuple(values)) result = sorted(grouped.mapGroups(agg_func).collect()) expected_result = get_python_result(data, lambda i: (i[0] % 2, i[0] % 3), agg_func) self.assertEqual(result, expected_result) + # We can also apply structured aggregate after groupByKey + grouped = ds.groupByKey(lambda row: row.i % 5, IntegerType()) + result = sorted(grouped.sum("j").collect()) + get_sum = lambda key: sum(filter(lambda i: i % 5 == key, range(100))) * 2 + result_row = Row("key", "sum(j)") + expected_result = [result_row(i, get_sum(i)) for i in range(5)] + self.assertEqual(result, expected_result) + class HiveContextSQLTests(ReusedPySparkTestCase): From df53348a891debcd64694ce05983eef3c5edbf23 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 19 Feb 2016 16:24:52 +0800 Subject: [PATCH 12/17] minor fix --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9a316efe9fb4..700a6235f9bb 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1201,7 +1201,7 @@ def test_dataset(self): self.assertEqual(ds3.count(), 100) result = ds3.applySchema(IntegerType()).collect() self.assertEqual(result[0][0], 0) - self.assertEqual(result[1][0], 3) + self.assertEqual(result[1][0], 1) def test_typed_aggregate(self): data = [(i, i * 2) for i in range(100)] From 97dcac2d4ba4994fc6c9a5167be0c69e724c56bb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 20 Feb 2016 14:33:17 +0800 Subject: [PATCH 13/17] add import --- python/pyspark/sql/tests.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 700a6235f9bb..0614e52d63ef 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -30,6 +30,13 @@ import time import datetime +if sys.version >= '3': + basestring = unicode = str + long = int + from functools import reduce +else: + from itertools import imap as map + import py4j try: import xmlrunner From 349b119d67f03d1fbd86374555d4c3486841985b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 21 Feb 2016 13:27:47 +0800 Subject: [PATCH 14/17] fix python 3 --- python/pyspark/sql/group.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index a11fbb595ec8..7e344d962437 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -231,13 +231,16 @@ class GroupedIterator(object): def __init__(self, inputs): self.inputs = BufferedIterator(inputs) - self.current_input = inputs.next() + self.current_input = next(inputs) self.current_key = self.current_input[0] self.current_values = GroupValuesIterator(self) def __iter__(self): return self + def __next__(self): + return self.next() + def next(self): if self.current_values is None: self._fetch_next_group() @@ -248,11 +251,11 @@ def next(self): def _fetch_next_group(self): if self.current_input is None: - self.current_input = self.inputs.next() + self.current_input = next(self.inputs) # Skip to next group, or consume all inputs and throw StopIteration exception. 
while self.current_input[0] == self.current_key: - self.current_input = self.inputs.next() + self.current_input = next(self.inputs) self.current_key = self.current_input[0] self.current_values = GroupValuesIterator(self) @@ -267,6 +270,9 @@ def __init__(self, outter): def __iter__(self): return self + def __next__(self): + return self.next() + def next(self): if self.outter.current_input is None: self._fetch_next_value() @@ -277,7 +283,7 @@ def next(self): def _fetch_next_value(self): if self.outter.inputs.head()[0] == self.outter.current_key: - self.outter.current_input = self.outter.inputs.next() + self.outter.current_input = next(self.outter.inputs) else: raise StopIteration @@ -292,9 +298,12 @@ def __init__(self, iterator): def __iter__(self): return self + def __next__(self): + return self.next() + def next(self): if self.buffered is None: - return self.iterator.next() + return next(self.iterator) else: item = self.buffered self.buffered = None @@ -302,7 +311,7 @@ def next(self): def head(self): if self.buffered is None: - self.buffered = self.iterator.next() + self.buffered = next(self.iterator) return self.buffered From 8c32d311b7189b9a7a227a89d84681e97d2997a3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Feb 2016 10:27:12 +0800 Subject: [PATCH 15/17] small fix --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 0614e52d63ef..7ad83647c70d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1216,7 +1216,7 @@ def test_typed_aggregate(self): sum_tuple = lambda values: sum(map(lambda value: value[0] * value[1], values)) def get_python_result(data, key_func, agg_func): - data.sort(key=key_func) + data = sorted(data, key=key_func) expected_result = [] import itertools for key, values in itertools.groupby(data, key_func): From aec6fc478b8ba9741034276bb7e0f453f51d9c33 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Feb 2016 14:50:03 +0800 Subject: [PATCH 16/17] update --- python/pyspark/sql/dataframe.py | 55 ++++++++++++++++----------------- python/pyspark/sql/tests.py | 43 +++++++++++++------------- 2 files changed, 48 insertions(+), 50 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index efa490d1bbb9..ade93e11021e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -288,6 +288,9 @@ def map(self, f): @since(2.0) def applySchema(self, schema=None): """ TODO """ + # TODO: should we throw exception instead? 
+ return self + if isinstance(self, PipelinedDataFrame): if schema is None: from pyspark.sql.types import _infer_type, _merge_type @@ -1401,16 +1404,9 @@ class PipelinedDataFrame(DataFrame): """ TODO """ - def __init__(self, prev, func=None, output_schema=None): + def __init__(self, prev, func): from pyspark.sql.group import GroupedData - if output_schema is None: - # should get it from java side - self._schema = StructType().add("binary", BinaryType(), False, {"pickled": True}) - else: - self._schema = output_schema - - self._output_schema = output_schema self._jdf_val = None self.is_cached = False self.sql_ctx = prev.sql_ctx @@ -1424,35 +1420,38 @@ def __init__(self, prev, func=None, output_schema=None): self._prev_jdf = prev._jgd elif not isinstance(prev, PipelinedDataFrame) or prev.is_cached: # This transformation is the first in its stage: + self._grouped = False self._func = func self._prev_jdf = prev._jdf - self._grouped = False else: - if func is None: - # This transformation is applying schema, no need to pipeline the function. - self._func = prev._func - else: - self._func = _pipeline_func(prev._func, func) + self._grouped = prev._grouped + self._func = _pipeline_func(prev._func, func) # maintain the pipeline. self._prev_jdf = prev._prev_jdf - self._grouped = prev._grouped + + def applySchema(self, schema=None): + if schema is None: + from pyspark.sql.types import _infer_type, _merge_type + # If no schema is specified, infer it from the whole data set. + jrdd = self._prev_jdf.javaToPython() + rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer())) + schema = rdd.mapPartitions(self._func).map(_infer_type).reduce(_merge_type) + + if isinstance(schema, StructType): + to_rows = lambda iterator: map(schema.toInternal, iterator) + else: + data_type = schema + schema = StructType().add("value", data_type) + to_row = lambda obj: (data_type.toInternal(obj), ) + to_rows = lambda iterator: map(to_row, iterator) + + jdf = self._create_jdf(_pipeline_func(self._func, to_rows), schema) + return DataFrame(jdf, self.sql_ctx) @property def _jdf(self): if self._jdf_val is None: - if self._output_schema is None: - self._jdf_val = self._create_jdf(self._func) - elif isinstance(self._output_schema, StructType): - schema = self._output_schema - to_row = lambda iterator: map(schema.toInternal, iterator) - self._jdf_val = self._create_jdf(_pipeline_func(self._func, to_row), schema) - else: - data_type = self._output_schema - schema = StructType().add("value", data_type) - converter = lambda obj: (data_type.toInternal(obj), ) - to_row = lambda iterator: map(converter, iterator) - self._jdf_val = self._create_jdf(_pipeline_func(self._func, to_row), schema) - + self._jdf_val = self._create_jdf(self._func) return self._jdf_val def _create_jdf(self, func, schema=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7ad83647c70d..314683404b33 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1161,54 +1161,53 @@ def test_functions_broadcast(self): broadcast(df1)._jdf.queryExecution().executedPlan() def test_dataset(self): - ds = self.sqlCtx.createDataFrame([(i, str(i)) for i in range(100)], ("key", "value")) + data = [(i, str(i)) for i in range(100)] + ds = self.sqlCtx.createDataFrame(data, ("key", "value")) + + def check_result(result, f): + expected_result = [] + for k, v in data: + expected_result.append(f(k, v)) + self.assertEqual(result, expected_result) # convert row to python dict ds2 = ds.map2(lambda row: {"key": row.key + 1, "value": 
row.value}) schema = StructType().add("key", IntegerType()).add("value", StringType()) ds3 = ds2.applySchema(schema) result = ds3.select("key").collect() - # test first 2 elements - self.assertEqual(result[0][0], 1) - self.assertEqual(result[1][0], 2) + check_result(result, lambda k, v: Row(key=k + 1)) # use a different but compatible schema schema = StructType().add("value", StringType()) result = ds2.applySchema(schema).collect() - self.assertEqual(result[0][0], "0") - self.assertEqual(result[1][0], "1") + check_result(result, lambda k, v: Row(value=v)) # use a flat schema ds2 = ds.map2(lambda row: row.key * 3) result = ds2.applySchema(IntegerType()).collect() - self.assertEqual(result[0][0], 0) - self.assertEqual(result[1][0], 3) + check_result(result, lambda k, v: Row(value=k * 3)) # schema can be inferred automatically - result = ds2.applySchema().collect() - self.assertEqual(result[0][0], 0) - self.assertEqual(result[1][0], 3) + result = ds.map2(lambda row: row.key + 10).applySchema().collect() + check_result(result, lambda k, v: Row(value=k + 10)) # If no schema is given, by default it's a single binary field struct type. from pyspark.sql.functions import length result = ds2.select(length("value")).collect() - self.assertTrue(result[0][0] > 0) - self.assertTrue(result[1][0] > 0) + self.assertEqual(len(result), 100) # If no schema is given, collect will return custom objects instead of rows. - result = ds2.collect() - self.assertEqual(result[0], 0) - self.assertEqual(result[1], 3) + result = ds.map2(lambda row: row.value + "#").collect() + check_result(result, lambda k, v: v + "#") # row count should be corrected even no schema is specified. - self.assertEqual(ds2.count(), 100) + self.assertEqual(ds.map2(lambda row: row.key + 1).count(), 100) - # typed operation still works on cached Dataset. - ds3 = ds2.cache().map2(lambda key: key / 3) + # call cache() in the middle of 2 typed operations. + ds3 = ds.map2(lambda row: row.key * 2).cache().map2(lambda key: key + 1) self.assertEqual(ds3.count(), 100) - result = ds3.applySchema(IntegerType()).collect() - self.assertEqual(result[0][0], 0) - self.assertEqual(result[1][0], 1) + result = ds3.collect() + check_result(result, lambda k, v: k * 2 + 1) def test_typed_aggregate(self): data = [(i, i * 2) for i in range(100)] From 1095d7f8c217ec006dd3b538d4d49da2cf5a287d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 23 Feb 2016 09:24:06 +0800 Subject: [PATCH 17/17] small cleanup --- python/pyspark/sql/dataframe.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ade93e11021e..b5260224c2ee 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -291,17 +291,6 @@ def applySchema(self, schema=None): # TODO: should we throw exception instead? return self - if isinstance(self, PipelinedDataFrame): - if schema is None: - from pyspark.sql.types import _infer_type, _merge_type - # If no schema is specified, infer it from the whole data set. - jrdd = self._prev_jdf.javaToPython() - rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer())) - schema = rdd.mapPartitions(self._func).map(_infer_type).reduce(_merge_type) - return PipelinedDataFrame(self, output_schema=schema) - else: - return self - @ignore_unicode_prefix @since(2.0) def mapPartitions2(self, func):