diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map.py b/python/pyspark/sql/tests/test_pandas_grouped_map.py
index 76119432662ba..cc6167e619285 100644
--- a/python/pyspark/sql/tests/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_grouped_map.py
@@ -545,13 +545,13 @@ def f(pdf):
 
     def test_grouped_over_window_with_key(self):
 
-        data = [(0, 1, "2018-03-10T00:00:00+00:00", False),
-                (1, 2, "2018-03-11T00:00:00+00:00", False),
-                (2, 2, "2018-03-12T00:00:00+00:00", False),
-                (3, 3, "2018-03-15T00:00:00+00:00", False),
-                (4, 3, "2018-03-16T00:00:00+00:00", False),
-                (5, 3, "2018-03-17T00:00:00+00:00", False),
-                (6, 3, "2018-03-21T00:00:00+00:00", False)]
+        data = [(0, 1, "2018-03-10T00:00:00+00:00", [0]),
+                (1, 2, "2018-03-11T00:00:00+00:00", [0]),
+                (2, 2, "2018-03-12T00:00:00+00:00", [0]),
+                (3, 3, "2018-03-15T00:00:00+00:00", [0]),
+                (4, 3, "2018-03-16T00:00:00+00:00", [0]),
+                (5, 3, "2018-03-17T00:00:00+00:00", [0]),
+                (6, 3, "2018-03-21T00:00:00+00:00", [0])]
 
         expected_window = [
             {'start': datetime.datetime(2018, 3, 10, 0, 0),
@@ -562,30 +562,43 @@ def test_grouped_over_window_with_key(self):
              'end': datetime.datetime(2018, 3, 25, 0, 0)},
         ]
 
-        expected = {0: (1, expected_window[0]),
-                    1: (2, expected_window[0]),
-                    2: (2, expected_window[0]),
-                    3: (3, expected_window[1]),
-                    4: (3, expected_window[1]),
-                    5: (3, expected_window[1]),
-                    6: (3, expected_window[2])}
+        expected_key = {0: (1, expected_window[0]),
+                        1: (2, expected_window[0]),
+                        2: (2, expected_window[0]),
+                        3: (3, expected_window[1]),
+                        4: (3, expected_window[1]),
+                        5: (3, expected_window[1]),
+                        6: (3, expected_window[2])}
+
+        # id -> array of the group value, one entry per record in the window
+        expected = {0: [1],
+                    1: [2, 2],
+                    2: [2, 2],
+                    3: [3, 3, 3],
+                    4: [3, 3, 3],
+                    5: [3, 3, 3],
+                    6: [3]}
 
         df = self.spark.createDataFrame(data, ['id', 'group', 'ts', 'result'])
         df = df.select(col('id'), col('group'), col('ts').cast('timestamp'), col('result'))
 
-        @pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
         def f(key, pdf):
             group = key[0]
             window_range = key[1]
-            # Result will be True if group and window range equal to expected
-            is_expected = pdf.id.apply(lambda id: (expected[id][0] == group and
-                                                   expected[id][1] == window_range))
-            return pdf.assign(result=is_expected)
 
-        result = df.groupby('group', window('ts', '5 days')).apply(f).select('result').collect()
+            # Make sure the group and window values in the key are correct
+            for _, i in pdf.id.iteritems():
+                assert expected_key[i][0] == group, "{} != {}".format(expected_key[i][0], group)
+                assert expected_key[i][1] == window_range, \
+                    "{} != {}".format(expected_key[i][1], window_range)
 
-        # Check that all group and window_range values from udf matched expected
-        self.assertTrue(all([r[0] for r in result]))
+            return pdf.assign(result=[[group] * len(pdf)] * len(pdf))
+
+        result = df.groupby('group', window('ts', '5 days')).applyInPandas(f, df.schema)\
+            .select('id', 'result').collect()
+
+        for r in result:
+            self.assertListEqual(expected[r[0]], r[1])
 
     def test_case_insensitive_grouping_column(self):
         # SPARK-31915: case-insensitive grouping column should work.
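
For context on the API this patch migrates to: `groupby(...).apply(udf)` with a `PandasUDFType.GROUPED_MAP` pandas UDF is the deprecated Spark 2.x pattern; since Spark 3.0, `GroupedData.applyInPandas(func, schema)` takes a plain Python function plus an explicit result schema, and when that function accepts two arguments, Spark passes the grouping key as the first one. A minimal standalone sketch of the pattern the updated test exercises (the session bootstrap, toy data, and the `tag_with_group` name are illustrative assumptions, not part of the patch):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, window

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [(0, 1, "2018-03-10T00:00:00+00:00"),
     (1, 2, "2018-03-11T00:00:00+00:00")],
    ["id", "group", "ts"],
).withColumn("ts", col("ts").cast("timestamp"))

def tag_with_group(key, pdf):
    # Grouping by a column plus a window expression yields a two-element
    # key: the group value and the window range, which (per the test's
    # expectations) compares equal to a {'start': ..., 'end': ...} dict.
    group, window_range = key
    # pdf holds the original df columns, grouping columns included;
    # overwrite 'group' with the key value so the schema stays unchanged.
    return pdf.assign(group=group)

# The result schema is declared explicitly; here the output shape
# matches the input, so df.schema can be reused.
result = df.groupby("group", window("ts", "5 days")).applyInPandas(
    tag_with_group, df.schema)
result.show()
```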