diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 76fbb0c9aa4c..0997eedf8593 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -18,17 +18,19 @@
 import sys
 import warnings
 import random
+from itertools import chain
 
 if sys.version >= '3':
     basestring = unicode = str
     long = int
     from functools import reduce
 else:
-    from itertools import imap as map
+    from itertools import imap as map, ifilter as filter
 
 from pyspark import since
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
+from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, \
+    PickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.sql.types import _parse_datatype_json_string
@@ -36,7 +38,7 @@
 from pyspark.sql.readwriter import DataFrameWriter
 from pyspark.sql.types import *
 
-__all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]
+__all__ = ["DataFrame", "Dataset", "DataFrameNaFunctions", "DataFrameStatFunctions"]
 
 
 class DataFrame(object):
@@ -69,21 +71,32 @@ class DataFrame(object):
     """
 
     def __init__(self, jdf, sql_ctx):
-        self._jdf = jdf
+        if jdf is not None:
+            self._jdf = jdf
         self.sql_ctx = sql_ctx
         self._sc = sql_ctx and sql_ctx._sc
         self.is_cached = False
         self._schema = None  # initialized lazily
         self._lazy_rdd = None
 
+    def _deserializer(self):
+        if self._jdf.isOutputPickled():
+            # If the underlying Java DataFrame's output is pickled, the query engine
+            # doesn't know the real schema of the data and just keeps the pickled binary
+            # for each custom object (without batching).
+            # So we need to use a non-batched deserializer for this DataFrame.
+            return PickleSerializer()
+        else:
+            return BatchedSerializer(PickleSerializer())
+
     @property
     @since(1.3)
     def rdd(self):
-        """Returns the content as an :class:`pyspark.RDD` of :class:`Row`.
+        """Returns the content as an :class:`pyspark.RDD` of :class:`Row` or custom object.
         """
         if self._lazy_rdd is None:
             jrdd = self._jdf.javaToPython()
-            self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
+            self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, self._deserializer())
         return self._lazy_rdd
 
     @property
@@ -232,14 +245,14 @@ def count(self):
     @ignore_unicode_prefix
     @since(1.3)
     def collect(self):
-        """Returns all the records as a list of :class:`Row`.
+        """Returns all the records as a list.
 
         >>> df.collect()
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
         with SCCallSiteSync(self._sc) as css:
             port = self._jdf.collectToPython()
-        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
+        return list(_load_from_socket(port, self._deserializer()))
 
     @ignore_unicode_prefix
     @since(1.3)
@@ -257,7 +270,7 @@ def limit(self, num):
     @ignore_unicode_prefix
     @since(1.3)
     def take(self, num):
-        """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
+        """Returns the first ``num`` records as a :class:`list`.
 
         >>> df.take(2)
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
@@ -265,45 +278,86 @@ def take(self, num):
         with SCCallSiteSync(self._sc) as css:
             port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe(
                 self._jdf, num)
-        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
+        return list(_load_from_socket(port, self._deserializer()))
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def applySchema(self, schema=None):
+        """Returns a new :class:`DataFrame` by applying the given schema, or infers the schema
+        from all of the records if no schema is given.
+
+        It is only allowed to apply a schema to a DataFrame which is returned by typed operations,
+        e.g. map, flatMap, etc. The record type of the schema-applied DataFrame will be Row.
+
+        >>> ds = df.map(lambda row: row.name)
+        >>> ds.collect()
+        [u'Alice', u'Bob']
+        >>> ds.schema
+        StructType(List(StructField(value,BinaryType,false)))
+        >>> ds2 = ds.applySchema(StringType())
+        >>> ds2.collect()
+        [Row(value=u'Alice'), Row(value=u'Bob')]
+        >>> ds2.schema
+        StructType(List(StructField(value,StringType,true)))
+        >>> ds3 = ds.applySchema()
+        >>> ds3.collect()
+        [Row(value=u'Alice'), Row(value=u'Bob')]
+        >>> ds3.schema
+        StructType(List(StructField(value,StringType,true)))
+        """
+        msg = "Cannot apply schema to a DataFrame which is not returned by typed operations"
+        raise Exception(msg)
 
     @ignore_unicode_prefix
     @since(1.3)
     def map(self, f):
-        """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`.
+        """ Returns a new :class:`DataFrame` by applying the ``f`` function to each record.
 
-        This is a shorthand for ``df.rdd.map()``.
+        .. versionchanged:: 2.0
+            Now it returns a :class:`DataFrame` instead of a :class:`RDD`.
+            The schema of the returned :class:`DataFrame` is a single binary field struct type;
+            please call `applySchema` to set the correct schema before applying structured
+            operations, e.g. select, sort, groupBy, etc.
 
         >>> df.map(lambda p: p.name).collect()
         [u'Alice', u'Bob']
         """
-        return self.rdd.map(f)
+        return self.mapPartitions(lambda iterator: map(f, iterator))
 
     @ignore_unicode_prefix
     @since(1.3)
     def flatMap(self, f):
-        """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`,
+        """ Returns a new :class:`DataFrame` by first applying the ``f`` function to each record,
         and then flattening the results.
 
-        This is a shorthand for ``df.rdd.flatMap()``.
+        .. versionchanged:: 2.0
+            Now it returns a :class:`DataFrame` instead of a :class:`RDD`.
+            The schema of the returned :class:`DataFrame` is a single binary field struct type;
+            please call `applySchema` to set the correct schema before applying structured
+            operations, e.g. select, sort, groupBy, etc.
 
         >>> df.flatMap(lambda p: p.name).collect()
         [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b']
         """
-        return self.rdd.flatMap(f)
+        return self.mapPartitions(lambda iterator: chain.from_iterable(map(f, iterator)))
 
+    @ignore_unicode_prefix
     @since(1.3)
-    def mapPartitions(self, f, preservesPartitioning=False):
-        """Returns a new :class:`RDD` by applying the ``f`` function to each partition.
+    def mapPartitions(self, f):
+        """Returns a new :class:`DataFrame` by applying the ``f`` function to each partition.
 
-        This is a shorthand for ``df.rdd.mapPartitions()``.
+        .. versionchanged:: 2.0
+            Now it returns a :class:`DataFrame` instead of a :class:`RDD`, and the
+            `preservesPartitioning` parameter is removed.
+            The schema of the returned :class:`DataFrame` is a single binary field struct type;
+            please call `applySchema` to set the correct schema before applying structured
+            operations, e.g. select, sort, groupBy, etc.
 
-        >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
-        >>> def f(iterator): yield 1
-        >>> rdd.mapPartitions(f).sum()
-        4
+        >>> f = lambda iterator: map(lambda i: 1, iterator)
+        >>> df.mapPartitions(f).collect()
+        [1, 1]
         """
-        return self.rdd.mapPartitions(f, preservesPartitioning)
+        return PipelinedDataFrame(self, f)
 
     @since(1.3)
     def foreach(self, f):
@@ -315,7 +369,7 @@ def foreach(self, f):
         ...     print(person.name)
         >>> df.foreach(f)
         """
-        return self.rdd.foreach(f)
+        self.rdd.foreach(f)
 
     @since(1.3)
     def foreachPartition(self, f):
@@ -328,7 +382,7 @@ def foreachPartition(self, f):
         ...     print(person.name)
         >>> df.foreachPartition(f)
         """
-        return self.rdd.foreachPartition(f)
+        self.rdd.foreachPartition(f)
 
     @since(1.3)
     def cache(self):
@@ -745,7 +799,7 @@ def head(self, n=None):
         :param n: int, default 1. Number of rows to return.
         :return: If n is greater than 1, return a list of :class:`Row`.
-            If n is 1, return a single Row.
+            If n is None, return a single Row.
 
         >>> df.head()
         Row(age=2, name=u'Alice')
@@ -843,13 +897,20 @@ def selectExpr(self, *expr):
     @ignore_unicode_prefix
     @since(1.3)
     def filter(self, condition):
-        """Filters rows using the given condition.
+        """Filters records using the given condition.
 
         :func:`where` is an alias for :func:`filter`.
 
         :param condition: a :class:`Column` of :class:`types.BooleanType`
            or a string of SQL expression.
 
+        .. versionchanged:: 2.0
+            Also allows the condition parameter to be a function that takes a record as input
+            and returns a boolean.
+            The schema of the returned :class:`DataFrame` is a single binary field struct type;
+            please call `applySchema` to set the correct schema before applying structured
+            operations, e.g. select, sort, groupBy, etc.
+
         >>> df.filter(df.age > 3).collect()
         [Row(age=5, name=u'Bob')]
         >>> df.where(df.age == 2).collect()
@@ -859,14 +920,20 @@
         [Row(age=5, name=u'Bob')]
         >>> df.where("age = 2").collect()
         [Row(age=2, name=u'Alice')]
+
+        >>> df.filter(lambda row: row.age > 3).collect()
+        [Row(age=5, name=u'Bob')]
+        >>> df.map(lambda row: row.age).filter(lambda age: age > 3).collect()
+        [5]
         """
         if isinstance(condition, basestring):
-            jdf = self._jdf.filter(condition)
+            return DataFrame(self._jdf.filter(condition), self.sql_ctx)
         elif isinstance(condition, Column):
-            jdf = self._jdf.filter(condition._jc)
+            return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
+        elif hasattr(condition, '__call__'):
+            return self.mapPartitions(lambda iterator: filter(condition, iterator))
         else:
             raise TypeError("condition should be string or Column")
-        return DataFrame(jdf, self.sql_ctx)
 
     where = filter
 
@@ -1404,6 +1471,83 @@ def toPandas(self):
 
     drop_duplicates = dropDuplicates
 
+
+Dataset = DataFrame
+
+
+class PipelinedDataFrame(DataFrame):
+
+    """
+    Pipelined typed operations on :class:`DataFrame`:
+
+    >>> df.map(lambda row: 2 * row.age).cache().map(lambda i: 2 * i).collect()
+    [8, 20]
+    >>> df.map(lambda row: 2 * row.age).map(lambda i: 2 * i).collect()
+    [8, 20]
+    """
+
+    def __init__(self, prev, func):
+        super(PipelinedDataFrame, self).__init__(None, prev.sql_ctx)
+        self._jdf_val = None
+        if not isinstance(prev, PipelinedDataFrame) or prev.is_cached:
+            # This is the beginning of this pipeline.
+            self._func = func
+            self._prev_jdf = prev._jdf
+        else:
+            self._func = _pipeline_func(prev._func, func)
+            # maintain the pipeline.
+ self._prev_jdf = prev._prev_jdf + + def applySchema(self, schema=None): + if schema is None: + from pyspark.sql.types import _infer_type, _merge_type + # If no schema is specified, infer it from the whole data set. + jrdd = self._prev_jdf.javaToPython() + rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer())) + schema = rdd.mapPartitions(self._func).map(_infer_type).reduce(_merge_type) + + if isinstance(schema, StructType): + to_rows = lambda iterator: map(schema.toInternal, iterator) + else: + data_type = schema + schema = StructType().add("value", data_type) + to_row = lambda obj: (data_type.toInternal(obj), ) + to_rows = lambda iterator: map(to_row, iterator) + + wrapped_func = self._wrap_func(_pipeline_func(self._func, to_rows), False) + jdf = self._prev_jdf.pythonMapPartitions(wrapped_func, schema.json()) + return DataFrame(jdf, self.sql_ctx) + + @property + def _jdf(self): + if self._jdf_val is None: + wrapped_func = self._wrap_func(self._func, True) + self._jdf_val = self._prev_jdf.pythonMapPartitions(wrapped_func) + return self._jdf_val + + def _wrap_func(self, func, output_binary): + if self._prev_jdf.isOutputPickled(): + deserializer = PickleSerializer() + else: + deserializer = AutoBatchedSerializer(PickleSerializer()) + + if output_binary: + serializer = PickleSerializer() + else: + serializer = AutoBatchedSerializer(PickleSerializer()) + + from pyspark.rdd import _wrap_function + return _wrap_function(self._sc, lambda _, iterator: func(iterator), + deserializer, serializer) + + +def _pipeline_func(prev_func, next_func): + """ + Pipeline 2 functions into one, while each of these 2 functions takes an iterator and + returns an iterator. + """ + return lambda iterator: next_func(prev_func(iterator)) + + def _to_scala_map(sc, jm): """ Convert a dict into a JVM Map. diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ee734cb43928..c91ca3573415 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -27,7 +27,7 @@ def dfapi(f): def _api(self): name = f.__name__ - jdf = getattr(self._jdf, name)() + jdf = getattr(self._jgd, name)() return DataFrame(jdf, self.sql_ctx) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -37,7 +37,7 @@ def _api(self): def df_varargs_api(f): def _api(self, *args): name = f.__name__ - jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) + jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, args)) return DataFrame(jdf, self.sql_ctx) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -54,8 +54,8 @@ class GroupedData(object): .. 
versionadded:: 1.3
     """
 
-    def __init__(self, jdf, sql_ctx):
-        self._jdf = jdf
+    def __init__(self, jgd, sql_ctx):
+        self._jgd = jgd
         self.sql_ctx = sql_ctx
 
     @ignore_unicode_prefix
@@ -83,11 +83,11 @@ def agg(self, *exprs):
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
-            jdf = self._jdf.agg(exprs[0])
+            jdf = self._jgd.agg(exprs[0])
         else:
            # Columns
            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
-            jdf = self._jdf.agg(exprs[0]._jc,
+            jdf = self._jgd.agg(exprs[0]._jc,
                                 _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
         return DataFrame(jdf, self.sql_ctx)
 
@@ -187,9 +187,9 @@ def pivot(self, pivot_col, values=None):
         [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
         """
         if values is None:
-            jgd = self._jdf.pivot(pivot_col)
+            jgd = self._jgd.pivot(pivot_col)
         else:
-            jgd = self._jdf.pivot(pivot_col, values)
+            jgd = self._jgd.pivot(pivot_col, values)
         return GroupedData(jgd, self.sql_ctx)
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 90fd7696910e..d782a105a5c3 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -30,6 +30,11 @@
 import time
 import datetime
 
+from itertools import chain
+
+if sys.version < '3':
+    from itertools import imap as map, ifilter as filter
+
 import py4j
 try:
     import xmlrunner
@@ -346,7 +351,7 @@ def test_basic_functions(self):
 
     def test_apply_schema_to_row(self):
         df = self.sqlCtx.read.json(self.sc.parallelize(["""{"a":2}"""]))
-        df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema)
+        df2 = self.sqlCtx.createDataFrame(df.rdd.map(lambda x: x), df.schema)
         self.assertEqual(df.collect(), df2.collect())
 
         rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
@@ -1178,6 +1183,74 @@ def test_functions_broadcast(self):
         # planner should not crash without a join
         broadcast(df1)._jdf.queryExecution().executedPlan()
 
+    def test_basic_typed_operations(self):
+        data = [(i, str(i)) for i in range(100)]
+        ds = self.sqlCtx.createDataFrame(data, ("key", "value"))
+
+        def check_result(result, f):
+            expected_result = []
+            for k, v in data:
+                expected_result.append(f(k, v))
+            self.assertEqual(result, expected_result)
+
+        # convert row to python dict
+        ds2 = ds.map(lambda row: {"key": row.key + 1, "value": row.value})
+        schema = StructType().add("key", IntegerType()).add("value", StringType())
+        ds3 = ds2.applySchema(schema)
+        result = ds3.select("key").collect()
+        check_result(result, lambda k, v: Row(key=k + 1))
+
+        # use a different but compatible schema
+        schema = StructType().add("value", StringType())
+        result = ds2.applySchema(schema).collect()
+        check_result(result, lambda k, v: Row(value=v))
+
+        # use a flat schema
+        ds2 = ds.map(lambda row: row.key * 3)
+        result = ds2.applySchema(IntegerType()).collect()
+        check_result(result, lambda k, v: Row(value=k * 3))
+
+        # schema can be inferred automatically
+        result = ds.map(lambda row: row.key + 10).applySchema().collect()
+        check_result(result, lambda k, v: Row(value=k + 10))
+
+        # If no schema is given, by default it's a single binary field struct type.
+        from pyspark.sql.functions import length
+        result = ds2.select(length("value")).collect()
+        self.assertEqual(len(result), 100)
+
+        # If no schema is given, collect will return custom objects instead of rows.
+        result = ds.map(lambda row: row.value + "#").collect()
+        check_result(result, lambda k, v: v + "#")
+
+        # cannot apply schema to a Dataset not returned by typed operations.
+        msg = "Cannot apply schema to a DataFrame which is not returned by typed operations"
+        self.assertRaisesRegexp(Exception, msg, lambda: ds.applySchema())
+
+        # row count should be correct even if no schema is specified.
+        self.assertEqual(ds.map(lambda row: row.key + 1).count(), 100)
+
+        # call cache() in the middle of 2 typed operations.
+        ds2 = ds.map(lambda row: row.key * 2).cache().map(lambda key: key + 1)
+        self.assertEqual(ds2.count(), 100)
+        result = ds2.collect()
+        check_result(result, lambda k, v: k * 2 + 1)
+
+        # other typed operations
+        ds2 = ds.map(lambda row: row.key * 2)
+
+        result = ds2.flatMap(lambda i: iter([i, i + 1])).collect()
+        expected_result = chain.from_iterable(map(lambda i: [i * 2, i * 2 + 1], range(100)))
+        self.assertEqual(result, list(expected_result))
+
+        result = ds2.mapPartitions(lambda it: map(lambda i: i + 1, it)).collect()
+        expected_result = map(lambda i: i * 2 + 1, range(100))
+        self.assertEqual(result, list(expected_result))
+
+        result = ds2.filter(lambda i: i > 33).collect()
+        expected_result = filter(lambda i: i > 33, map(lambda i: i * 2, range(100)))
+        self.assertEqual(result, list(expected_result))
+
 
 class HiveContextSQLTests(ReusedPySparkTestCase):
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index f1fa13daa77e..83389ae2ef7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -88,8 +88,6 @@ package object expressions {
    */
   implicit class AttributeSeq(attrs: Seq[Attribute]) {
     /** Creates a StructType with a schema matching this `Seq[Attribute]`. */
-    def toStructType: StructType = {
-      StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
-    }
+    def toStructType: StructType = StructType.fromAttributes(attrs)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 3f97662957b8..da7f81c78546 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -208,8 +208,6 @@ case class CoGroup(
     left: LogicalPlan,
     right: LogicalPlan) extends BinaryNode with ObjectOperator {
 
-  override def producedAttributes: AttributeSet = outputSet
-
   override def deserializers: Seq[(Expression, Seq[Attribute])] =
     // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to resolve
     // the `keyDeserializer` based on either of them, here we pick the left one.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5f5b7f4c19cf..bb9f8d950279 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -27,7 +27,7 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.PythonRDD +import org.apache.spark.api.python.{PythonFunction, PythonRDD} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator -import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.execution.python.{EvaluatePython, LogicalMapPartitions} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -1729,9 +1729,13 @@ class DataFrame private[sql]( * Converts a JavaRDD to a PythonRDD. */ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val structType = schema // capture it for closure - val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) - EvaluatePython.javaToPython(rdd) + if (isOutputPickled) { + queryExecution.toRdd.map(_.getBinary(0)) + } else { + val structType = schema // capture it for closure + val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) + EvaluatePython.javaToPython(rdd) + } } protected[sql] def collectToPython(): Int = { @@ -1740,6 +1744,19 @@ class DataFrame private[sql]( } } + protected[sql] def isOutputPickled: Boolean = EvaluatePython.schemaOfPickled == schema + + protected[sql] def pythonMapPartitions(func: PythonFunction): DataFrame = withPlan { + LogicalMapPartitions(func, EvaluatePython.schemaOfPickled.toAttributes, logicalPlan) + } + + protected[sql] def pythonMapPartitions( + func: PythonFunction, + schemaJson: String): DataFrame = withPlan { + val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType] + LogicalMapPartitions(func, schema.toAttributes, logicalPlan) + } + /** * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with * an execution. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dd8c96d5fa1d..3101182114e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -329,6 +329,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, planLater(left), planLater(right)) :: Nil + case execution.python.LogicalMapPartitions(f, output, child) => + execution.python.PhysicalMapPartitions(f, output, planLater(child)) :: Nil + case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { ShuffleExchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 8c46516594a2..837a8aec8abb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -55,12 +55,16 @@ object EvaluatePython { new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) def takeAndServe(df: DataFrame, n: Int): Int = { - registerPicklers() + val rows = df.queryExecution.executedPlan.executeTake(n).iterator + val iter = if (df.isOutputPickled) { + rows.map(_.getBinary(0)) + } else { + registerPicklers() + new SerDeUtil.AutoBatchedPickler( + rows.map { row => EvaluatePython.toJava(row, df.schema) } + ) + } df.withNewExecutionId { - val iter = new SerDeUtil.AutoBatchedPickler( - df.queryExecution.executedPlan.executeTake(n).iterator.map { row => - EvaluatePython.toJava(row, df.schema) - }) PythonRDD.serveIterator(iter, s"serve-DataFrame") } } @@ -258,4 +262,12 @@ object EvaluatePython { new SerDeUtil.AutoBatchedPickler(iter) } } + + /** + * The default schema for Python Dataset which is returned by typed operation. + */ + val schemaOfPickled = { + val metaPickled = new MetadataBuilder().putBoolean("pickled", true).build() + new StructType().add("value", BinaryType, nullable = false, metadata = metaPickled) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/LogicalMapPartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/LogicalMapPartitions.scala new file mode 100644 index 000000000000..561f49f4b215 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/LogicalMapPartitions.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.api.python.PythonFunction
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
+
+/**
+ * A relation produced by applying the given Python function to each partition of the `child`.
+ */
+case class LogicalMapPartitions(
+    func: PythonFunction,
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+
+  override def expressions: Seq[Expression] = Nil
+
+  override def references: AttributeSet = child.outputSet
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PhysicalMapPartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PhysicalMapPartitions.scala
new file mode 100644
index 000000000000..c8b3a0bbf153
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PhysicalMapPartitions.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import net.razorvine.pickle.{Pickler, Unpickler}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{PythonFunction, PythonRunner}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericMutableRow, UnsafeProjection}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
+
+/**
+ * Launches a Python runner, sends all rows to it, then applies the given function to each row in
+ * the launched Python runner, and sends the results back.
+ *
+ * Note that if the schema of this plan equals [[EvaluatePython.schemaOfPickled]], the result is
+ * not a row, and we can't understand the data without the real schema, so we just wrap the
+ * pickled binary as a single-field row.
+ * If the schema of the child plan equals [[EvaluatePython.schemaOfPickled]], the input data is a
+ * single-field row with pickled binary data, so we just extract the binary and send it to the
+ * Python runner without serializing it.
+ */ +case class PhysicalMapPartitions( + func: PythonFunction, + output: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override def expressions: Seq[Expression] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val isChildPickled = EvaluatePython.schemaOfPickled == child.schema + val isOutputPickled = EvaluatePython.schemaOfPickled == schema + + inputRDD.mapPartitions { iter => + val inputIterator = if (isChildPickled) { + iter.map(_.getBinary(0)) + } else { + EvaluatePython.registerPicklers() // register pickler for Row + val pickle = new Pickler + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + iter.grouped(100).map { inputRows => + val toBePickled = inputRows.map { row => + EvaluatePython.toJava(row, child.schema) + }.toArray + pickle.dumps(toBePickled) + } + } + + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = + new PythonRunner( + func, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val toUnsafe = UnsafeProjection.create(output, output) + + if (isOutputPickled) { + val row = new GenericMutableRow(1) + outputIterator.map { bytes => + row(0) = bytes + toUnsafe(row) + } + } else { + val unpickle = new Unpickler + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + toUnsafe(EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow]) + } + } + } + } +}
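Usage sketch (illustrative only, not part of the patch): the snippet below walks through the typed-operation API this change adds, namely map/filter with plain Python functions followed by applySchema() to return to structured operations. It mirrors the doctests and test_basic_typed_operations above, and assumes an existing SQLContext named sqlCtx.

from pyspark.sql.types import IntegerType, StringType, StructType

df = sqlCtx.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ("key", "value"))

# map/filter now return a DataFrame whose schema is a single binary field,
# so collect() yields the custom objects directly.
doubled = df.map(lambda row: row.key * 2)      # a PipelinedDataFrame
big = doubled.filter(lambda key: key > 2)      # stays in the same Python pipeline
print(big.collect())                           # [4, 6]

# applySchema() converts the pickled records back into rows with a real schema,
# after which structured operations (select, sort, groupBy, ...) work again.
with_schema = doubled.applySchema(IntegerType())
print(with_schema.collect())                   # [Row(value=2), Row(value=4), Row(value=6)]

# A struct schema works too, when the mapped records are dicts or tuples.
records = df.map(lambda row: {"key": row.key + 1, "value": row.value})
schema = StructType().add("key", IntegerType()).add("value", StringType())
print(records.applySchema(schema).select("key").collect())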
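Pipelining note (illustrative only): _pipeline_func in the patch simply composes iterator-to-iterator functions, which is why a chain of map/filter/flatMap calls on a non-cached PipelinedDataFrame reaches the Python worker as a single pythonMapPartitions stage. A minimal pure-Python sketch of that composition, with made-up data:

def _pipeline_func(prev_func, next_func):
    # Same helper as in the patch: feed prev_func's output iterator into next_func.
    return lambda iterator: next_func(prev_func(iterator))

double = lambda it: (x * 2 for x in it)          # stands in for df.map(lambda x: x * 2)
keep_big = lambda it: (x for x in it if x > 2)   # stands in for .filter(lambda x: x > 2)
fused = _pipeline_func(double, keep_big)         # the single function shipped to the worker

print(list(fused(iter([1, 2, 3]))))  # [4, 6]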