diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 9c760e3527be..ba4c4feec75c 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -51,11 +51,12 @@
 from pyspark.sql.group import GroupedData
 from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
 from pyspark.sql.window import Window, WindowSpec
+from pyspark.sql.cogroup import CoGroupedData


 __all__ = [
     'SparkSession', 'SQLContext', 'UDFRegistration',
     'DataFrame', 'GroupedData', 'Column', 'Catalog', 'Row',
     'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
-    'DataFrameReader', 'DataFrameWriter'
+    'DataFrameReader', 'DataFrameWriter', 'CoGroupedData'
 ]
diff --git a/python/pyspark/sql/cogroup.py b/python/pyspark/sql/cogroup.py
index 9b725e4bafe7..ef87e703bce1 100644
--- a/python/pyspark/sql/cogroup.py
+++ b/python/pyspark/sql/cogroup.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import sys

 from pyspark import since
 from pyspark.rdd import PythonEvalType
@@ -43,9 +44,9 @@ def apply(self, udf):
         as a `DataFrame`.

         The user-defined function should take two `pandas.DataFrame` and return another
-        `pandas.DataFrame`. For each side of the cogroup, all columns are passed together
-        as a `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame`
-        are combined as a :class:`DataFrame`.
+        `pandas.DataFrame`. For each side of the cogroup, all columns are passed together as a
+        `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame` are combined as
+        a :class:`DataFrame`.

         The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
         returnType of the pandas udf.
@@ -61,15 +62,16 @@ def apply(self, udf):

         >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
         >>> df1 = spark.createDataFrame(
-        ...     [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
-        ...     ("time", "id", "v1"))
+        ...    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
+        ...    ("time", "id", "v1"))
         >>> df2 = spark.createDataFrame(
-        ...     [(20000101, 1, "x"), (20000101, 2, "y")],
-        ...     ("time", "id", "v2"))
-        >>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
+        ...    [(20000101, 1, "x"), (20000101, 2, "y")],
+        ...    ("time", "id", "v2"))
+        >>> @pandas_udf("time int, id int, v1 double, v2 string",
+        ...     PandasUDFType.COGROUPED_MAP)  # doctest: +SKIP
         ... def asof_join(l, r):
         ...     return pd.merge_asof(l, r, on="time", by="id")
-        >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
+        >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()  # doctest: +SKIP
         +--------+---+---+---+
         |    time| id| v1| v2|
         +--------+---+---+---+
@@ -79,6 +81,27 @@ def apply(self, udf):
         |20000102|  2|4.0|  y|
         +--------+---+---+---+

+        Alternatively, the user can define a function that takes three arguments. In this case,
+        the grouping key(s) will be passed as the first argument and the data will be passed as the
+        second and third arguments. The grouping key(s) will be passed as a tuple of numpy data
+        types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as two
+        `pandas.DataFrame` containing all columns from the original Spark DataFrames.
+
+        >>> @pandas_udf("time int, id int, v1 double, v2 string",
+        ...     PandasUDFType.COGROUPED_MAP)  # doctest: +SKIP
+        ... def asof_join(k, l, r):
+        ...     if k == (1,):
+        ...         return pd.merge_asof(l, r, on="time", by="id")
+        ...     else:
+        ...         return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
+        >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()  # doctest: +SKIP
+        +--------+---+---+---+
+        |    time| id| v1| v2|
+        +--------+---+---+---+
+        |20000101|  1|1.0|  x|
+        |20000102|  1|3.0|  x|
+        +--------+---+---+---+
+
         .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
         """
@@ -96,3 +119,25 @@ def apply(self, udf):
 def _extract_cols(gd):
     df = gd._df
     return [df[col] for col in df.columns]
+
+
+def _test():
+    import doctest
+    from pyspark.sql import SparkSession
+    import pyspark.sql.cogroup
+    globs = pyspark.sql.cogroup.__dict__.copy()
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("sql.cogroup tests")\
+        .getOrCreate()
+    globs['spark'] = spark
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.cogroup, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
+    spark.stop()
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py
index 7f3f7fa3168a..bc2265fc5fe1 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py
@@ -32,14 +32,9 @@
     import pyarrow as pa


-"""
-Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
-from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
-"""
-if sys.version < '3':
-    _check_column_type = False
-else:
-    _check_column_type = True
+# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
+# from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
+_check_column_type = sys.version >= '3'


 @unittest.skipIf(
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
index adbe2d103ade..8918d5ac0cdd 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -37,14 +37,9 @@
     import pyarrow as pa


-"""
-Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
-from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
-"""
-if sys.version < '3':
-    _check_column_type = False
-else:
-    _check_column_type = True
+# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
+# from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
+_check_column_type = sys.version >= '3'


 @unittest.skipIf(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index f6d13be0e89b..4d4731870700 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType}
  */
 @Stable
 class RelationalGroupedDataset protected[sql](
-    val df: DataFrame,
-    val groupingExprs: Seq[Expression],
+    private[sql] val df: DataFrame,
+    private[sql] val groupingExprs: Seq[Expression],
     groupType: RelationalGroupedDataset.GroupType) {

   private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {