Changes from 4 commits
3 changes: 2 additions & 1 deletion python/pyspark/sql/__init__.py
@@ -51,11 +51,12 @@
from pyspark.sql.group import GroupedData
from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
from pyspark.sql.window import Window, WindowSpec
from pyspark.sql.cogroup import CoGroupedData


__all__ = [
'SparkSession', 'SQLContext', 'UDFRegistration',
'DataFrame', 'GroupedData', 'Column', 'Catalog', 'Row',
'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
'DataFrameReader', 'DataFrameWriter', 'CoGroupedData'
]
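Since `CoGroupedData` is now exported from the `pyspark.sql` package root, it can be imported directly. A quick sanity check, assuming a PySpark build that includes this change:

    # Assumes a PySpark build that includes this branch.
    from pyspark.sql import CoGroupedData

    # CoGroupedData is not constructed directly; GroupedData.cogroup returns it,
    # so the top-level export mainly serves isinstance checks and type hints.
    print(CoGroupedData.__module__)  # pyspark.sql.cogroup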
60 changes: 51 additions & 9 deletions python/pyspark/sql/cogroup.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys

from pyspark import since
from pyspark.rdd import PythonEvalType
@@ -43,9 +44,13 @@ def apply(self, udf):
as a `DataFrame`.

The user-defined function should take two `pandas.DataFrame` and return another
`pandas.DataFrame`. Alternatively, a user-defined function which additionally takes
a Tuple can be provided, in which case the cogroup key will be passed in as the Tuple
parameter.

For each side of the cogroup, all columns are passed together as a `pandas.DataFrame`
to the user-function and the returned `pandas.DataFrame` are combined as a
:class:`DataFrame`.

The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
returnType of the pandas udf.
@@ -61,15 +66,15 @@ def apply(self, udf):

>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df1 = spark.createDataFrame(
... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
... ("time", "id", "v1"))
>>> df2 = spark.createDataFrame(
... [(20000101, 1, "x"), (20000101, 2, "y")],
... ("time", "id", "v2"))
>>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
... [(20000101, 1, "x"), (20000101, 2, "y")],
... ("time", "id", "v2"))
>>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP) # doctest: +SKIP
... def asof_join(l, r):
... return pd.merge_asof(l, r, on="time", by="id")
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP
+--------+---+---+---+
| time| id| v1| v2|
+--------+---+---+---+
@@ -78,6 +83,19 @@ def apply(self, udf):
|20000101| 2|2.0| y|
|20000102| 2|4.0| y|
+--------+---+---+---+
>>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP) # doctest: +SKIP
... def asof_join(k, l, r):
... if k == (1,):
... return pd.merge_asof(l, r, on="time", by="id")
... else:
... return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP
+--------+---+---+---+
| time| id| v1| v2|
+--------+---+---+---+
|20000101| 1|1.0| x|
|20000102| 1|3.0| x|
+--------+---+---+---+

.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`

@@ -96,3 +114,27 @@ def apply(self, udf):
def _extract_cols(gd):
df = gd._df
return [df[col] for col in df.columns]


def _test():
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.cogroup
globs = pyspark.sql.cogroup.__dict__.copy()
spark = SparkSession.builder\
.master("local[4]")\
.appName("sql.cogroup tests")\
.getOrCreate()
sc = spark.sparkContext
globs['sc'] = sc
globs['spark'] = spark
(failure_count, test_count) = doctest.testmod(
pyspark.sql.cogroup, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
spark.stop()
if failure_count:
sys.exit(-1)


if __name__ == "__main__":
_test()
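Pulled out of the doctests, the new keyed variant looks like this as a standalone script. This is a sketch assuming a Spark build that contains this PR plus pandas and pyarrow installed (the reason the doctests above are marked `# doctest: +SKIP`); the app name is arbitrary:

    import pandas as pd

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    spark = SparkSession.builder.master("local[4]").appName("cogroup-demo").getOrCreate()

    df1 = spark.createDataFrame(
        [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
        ("time", "id", "v1"))
    df2 = spark.createDataFrame(
        [(20000101, 1, "x"), (20000101, 2, "y")],
        ("time", "id", "v2"))

    @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
    def asof_join(k, l, r):
        # k is the cogroup key as a tuple (here the single grouping column, id);
        # l and r are the matching slices of df1 and df2 as pandas DataFrames.
        if k == (1,):
            return pd.merge_asof(l, r, on="time", by="id")
        return pd.DataFrame(columns=["time", "id", "v1", "v2"])

    # Only the id == 1 rows survive, matching the second doctest's expected table.
    df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
    spark.stop()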
12 changes: 3 additions & 9 deletions python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py
@@ -32,15 +32,9 @@
import pyarrow as pa


"""
Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
"""
if sys.version < '3':
_check_column_type = False
else:
_check_column_type = True

# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
# from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
_check_column_type = sys.version >= '3'

@unittest.skipIf(
not have_pandas or not have_pyarrow,
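For context on the flag itself: the tests compare frames with `pandas.testing.assert_frame_equal`, whose `check_column_type` argument controls whether the dtype of the column index is verified. A minimal sketch of the pattern; the `expected`/`result` frames here are illustrative, not taken from the test suite:

    import sys

    import pandas as pd
    from pandas.testing import assert_frame_equal

    # Under Python 2, DataFrame.assign(**kwargs) can produce a mix of unicode
    # and str column names, so the column-index type check is relaxed there.
    _check_column_type = sys.version >= '3'

    expected = pd.DataFrame({'id': [1, 2], 'v': [1.0, 2.0]})
    result = expected.assign(v=expected['v'] * 1.0)  # stand-in for a UDF result
    assert_frame_equal(expected, result, check_column_type=_check_column_type)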
11 changes: 3 additions & 8 deletions python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -37,14 +37,9 @@
import pyarrow as pa


"""
Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
"""
if sys.version < '3':
_check_column_type = False
else:
_check_column_type = True
# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
# from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
_check_column_type = sys.version >= '3'


@unittest.skipIf(
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType}
*/
@Stable
class RelationalGroupedDataset protected[sql](
private[sql] val df: DataFrame,
private[sql] val groupingExprs: Seq[Expression],
groupType: RelationalGroupedDataset.GroupType) {

private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {