@@ -47,6 +47,7 @@ private[spark] object PythonEvalType {
val SQL_GROUPED_AGG_PANDAS_UDF = 202
val SQL_WINDOW_AGG_PANDAS_UDF = 203
val SQL_SCALAR_PANDAS_ITER_UDF = 204
val SQL_COGROUPED_MAP_PANDAS_UDF = 205

def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
@@ -56,6 +57,7 @@
case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
}
}

1 change: 1 addition & 0 deletions python/pyspark/rdd.py
@@ -74,6 +74,7 @@ class PythonEvalType(object):
SQL_GROUPED_AGG_PANDAS_UDF = 202
SQL_WINDOW_AGG_PANDAS_UDF = 203
SQL_SCALAR_PANDAS_ITER_UDF = 204
SQL_COGROUPED_MAP_PANDAS_UDF = 205


def portable_hash(x):
37 changes: 37 additions & 0 deletions python/pyspark/serializers.py
@@ -356,6 +356,27 @@ def __repr__(self):
return "ArrowStreamPandasSerializer"


class InterleavedArrowReader(object):

def __init__(self, stream):
import pyarrow as pa
self._schema1 = pa.read_schema(stream)
Contributor Author:
I wanted to read these also using the message reader, but for some reason pa.read_schema(self._reader.read_next_message()) didn't work.

self._schema2 = pa.read_schema(stream)
self._reader = pa.MessageReader.open_stream(stream)

def __iter__(self):
return self

def __next__(self):
import pyarrow as pa
batch1 = pa.read_record_batch(self._reader.read_next_message(), self._schema1)
batch2 = pa.read_record_batch(self._reader.read_next_message(), self._schema2)
return batch1, batch2

def next(self):
return self.__next__()


class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
"""
Serializer used by Python worker to evaluate Pandas UDFs
@@ -401,6 +422,22 @@ def __repr__(self):
return "ArrowStreamPandasUDFSerializer"


class InterleavedArrowStreamPandasSerializer(ArrowStreamPandasUDFSerializer):

def __init__(self, timezone, safecheck, assign_cols_by_name):
super(InterleavedArrowStreamPandasSerializer, self).__init__(timezone, safecheck, assign_cols_by_name)

def load_stream(self, stream):
"""
Deserialize interleaved ArrowRecordBatches and yield, for each pair of batches, a tuple of two lists of pandas.Series (one list per cogrouped DataFrame).
"""
import pyarrow as pa
reader = InterleavedArrowReader(pa.input_stream(stream))
for batch1, batch2 in reader:
yield ( [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch1]).itercolumns()],
[self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch2]).itercolumns()])


class BatchedSerializer(Serializer):

"""
38 changes: 38 additions & 0 deletions python/pyspark/sql/cogroup.py
@@ -0,0 +1,38 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.sql.dataframe import DataFrame


class CoGroupedData(object):

def __init__(self, gd1, gd2):
self._gd1 = gd1
self._gd2 = gd2
self.sql_ctx = gd1.sql_ctx

def apply(self, udf):
all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
udf_column = udf(*all_cols)
jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr())
return DataFrame(jdf, self.sql_ctx)

@staticmethod
def _extract_cols(gd):
df = gd._df
return [df[col] for col in df.columns]
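
For orientation, a minimal usage sketch of the API this file adds, modelled on the test added later in this PR; the SparkSession bound to `spark` and the example data are illustrative assumptions, not part of the patch:

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Illustrative data only; assumes an existing SparkSession named `spark`.
df1 = spark.createDataFrame([(1, 20, 200), (1, 21, 210)], ['id', 'k', 'v'])
df2 = spark.createDataFrame([(1, 20, 2000)], ['id', 'k', 'v2'])

# A cogrouped-map UDF receives one pandas.DataFrame per side for each cogroup.
@pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP)
def merge_pandas(left, right):
    return pd.merge(left, right, how='outer', on=['k', 'id'])

result = df1.groupby('id').cogroup(df2.groupby('id')).apply(merge_pandas)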

3 changes: 3 additions & 0 deletions python/pyspark/sql/functions.py
@@ -2800,6 +2800,8 @@ class PandasUDFType(object):

GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF

COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF

GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF


@@ -3178,6 +3180,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]:
raise ValueError("Invalid functionType: "
"functionType must be one the values from PandasUDFType")
4 changes: 4 additions & 0 deletions python/pyspark/sql/group.py
@@ -22,6 +22,7 @@
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import *
from pyspark.sql.cogroup import CoGroupedData

__all__ = ["GroupedData"]

@@ -220,6 +221,9 @@ def pivot(self, pivot_col, values=None):
jgd = self._jgd.pivot(pivot_col, values)
return GroupedData(jgd, self._df)

def cogroup(self, other):
return CoGroupedData(self, other)

@since(2.3)
def apply(self, udf):
"""
101 changes: 101 additions & 0 deletions python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py
@@ -0,0 +1,101 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import unittest
import sys

from collections import OrderedDict
from decimal import Decimal

from pyspark.sql import Row
from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest

if have_pandas:
import pandas as pd
from pandas.util.testing import assert_frame_equal

if have_pyarrow:
import pyarrow as pa


"""
Tests below use pd.DataFrame.assign, which will infer mixed types (unicode/str) for column names
from kwargs with Python 2, so check_column_type=False is needed to avoid that check
"""
if sys.version < '3':
_check_column_type = False
else:
_check_column_type = True


@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message)
class CoGroupedMapPandasUDFTests(ReusedSQLTestCase):

@property
def data1(self):
return self.spark.range(10).toDF('id') \
.withColumn("ks", array([lit(i) for i in range(20, 30)])) \
.withColumn("k", explode(col('ks')))\
.withColumn("v", col('k') * 10)\
.drop('ks')

@property
def data2(self):
return self.spark.range(10).toDF('id') \
.withColumn("ks", array([lit(i) for i in range(20, 30)])) \
.withColumn("k", explode(col('ks'))) \
.withColumn("v2", col('k') * 100) \
.drop('ks')

def test_simple(self):
import pandas as pd

l = self.data1
r = self.data2

@pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP)
def merge_pandas(left, right):
return pd.merge(left, right, how='outer', on=['k', 'id'])

result = l\
.groupby('id')\
.cogroup(r.groupby(r.id))\
.apply(merge_pandas)\
.sort(['id', 'k'])\
.toPandas()

expected = pd\
.merge(l.toPandas(), r.toPandas(), how='outer', on=['k', 'id'])

assert_frame_equal(expected, result, check_column_type=_check_column_type)

Hi @d80tb7, I work with Li and am also interested in cogroup.

Can I ask how you were able to get your test to run? I wasn't able to run it without the following snippet:

if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_udf_grouped_map import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)

taken from the other similar tests like test_pandas_udf_grouped_map.

Contributor Author:
Hi @hjoo

So far I've just been running it via PyCharm's unit test runner under python 3. I suspect the problem you had was that the iterator I added wasn't compatible with python 2 (sorry!). I've fixed the iterator and added a similar snippet to the one you provided above. Now I can run using python/run-tests --testnames pyspark.sql.tests.test_pandas_udf_cogrouped_map

If you still have problems let me know the error you get and I'll take a look.

if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_udf_cogrouped_map import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
41 changes: 36 additions & 5 deletions python/pyspark/worker.py
@@ -38,7 +38,7 @@
from pyspark.rdd import PythonEvalType
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasUDFSerializer
BatchedSerializer, ArrowStreamPandasUDFSerializer, InterleavedArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type, StructType
from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark import shuffle
@@ -111,8 +111,25 @@ def verify_result_length(result, length):
map(verify_result_type, f(*iterator)))


def wrap_grouped_map_pandas_udf(f, return_type, argspec):
def wrap_cogrouped_map_pandas_udf(f, return_type):

def wrapped(left, right):
Contributor:
Is left a list of pd.Series here? Probably name them left_series and value_series to be more readable?

Contributor Author:
Yes, they are: they are the value series for the left and right sides of the cogroup respectively. Agreed that the names aren't the best. I'll improve them when I do a tidy up.
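
To make the input shapes concrete, a small sketch with made-up values: each side arrives as a list of pandas.Series (one per column), and pd.concat(..., axis=1) in the wrapper below reassembles each list into the DataFrame handed to the user function.

import pandas as pd

# Hypothetical cogroup for id=1; only the structure matters here.
left = [pd.Series([1, 1], name='id'),
        pd.Series([20, 21], name='k'),
        pd.Series([200, 210], name='v')]
right = [pd.Series([1], name='id'),
         pd.Series([20], name='k'),
         pd.Series([2000], name='v2')]

left_df = pd.concat(left, axis=1)    # columns: id, k, v
right_df = pd.concat(right, axis=1)  # columns: id, k, v2
# The user's function f is then invoked as f(left_df, right_df).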

import pandas as pd
result = f(pd.concat(left, axis=1), pd.concat(right, axis=1))
if not isinstance(result, pd.DataFrame):
raise TypeError("Return type of the user-defined function should be "
"pandas.DataFrame, but is {}".format(type(result)))
if not len(result.columns) == len(return_type):
raise RuntimeError(
"Number of columns of the returned pandas.DataFrame "
"doesn't match specified schema. "
"Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
return result

return lambda v: [(wrapped(v[0], v[1]), to_arrow_type(return_type))]


def wrap_grouped_map_pandas_udf(f, return_type, argspec):
def wrapped(key_series, value_series):
import pandas as pd

@@ -232,6 +249,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
argspec = _get_argspec(chained_func) # signature was lost when wrapping it
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
@@ -246,6 +265,7 @@ def read_udfs(pickleSer, infile, eval_type):
runner_conf = {}

if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
Expand All @@ -269,10 +289,13 @@ def read_udfs(pickleSer, infile, eval_type):

# Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
# pandas Series. See SPARK-27240.
df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
ser = InterleavedArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
else:
df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
df_for_struct)
ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
df_for_struct)
else:
ser = BatchedSerializer(PickleSerializer(), 100)

@@ -343,6 +366,14 @@ def map_batch(batch):
arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
# We assume there is only one UDF here because cogrouped map doesn't
# support combining multiple UDFs.
assert num_udfs == 1
arg_offsets, udf = read_single_udf(
pickleSer, infile, eval_type, runner_conf, udf_index=0)
udfs['f'] = udf
mapper_str = "lambda a: f(a)"
else:
# Create function like this:
# lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
@@ -970,6 +970,10 @@ class Analyzer(
// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
case f @ FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, _, _, left, right) =>
val leftAttributes2 = leftAttributes.map(x => resolveExpressionBottomUp(x, left).asInstanceOf[Attribute])
val rightAttributes2 = rightAttributes.map(x => resolveExpressionBottomUp(x, right).asInstanceOf[Attribute])
f.copy(leftAttributes = leftAttributes2, rightAttributes = rightAttributes2)
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
i.copy(right = dedupRight(left, right))
case e @ Except(left, right, _) if !e.duplicateResolved =>
@@ -2269,6 +2273,7 @@
}
}


/**
* Removes natural or using joins by calculating output columns based on output from two sides,
* Then apply a Project on a normal Join to eliminate natural or using join.
@@ -39,6 +39,18 @@ case class FlatMapGroupsInPandas(
override val producedAttributes = AttributeSet(output)
}


case class FlatMapCoGroupsInPandas(
leftAttributes: Seq[Attribute],
rightAttributes: Seq[Attribute],
functionExpr: Expression,
output: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode {
override val producedAttributes = AttributeSet(output)
}


trait BaseEvalPython extends UnaryNode {

def udfs: Seq[PythonUDF]