Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 52 additions & 48 deletions python/pyspark/pandas/tests/test_expanding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,37 +26,37 @@


class ExpandingTest(PandasOnSparkTestCase, TestUtils):
def _test_expanding_func(self, f):
def _test_expanding_func(self, ps_func, pd_func=None):
# NOTE(review): this span is a diff rendering; the `f`/getattr-based lines appear to be
# the replaced side and the `ps_func`/`pd_func` lines the added side.
# When pd_func is omitted, ps_func is used for both the pandas-on-Spark and pandas sides.
if not pd_func:
pd_func = ps_func
# A string is treated as a method name and turned into a callable.
if isinstance(pd_func, str):
pd_func = self.convert_str_to_lambda(pd_func)
if isinstance(ps_func, str):
ps_func = self.convert_str_to_lambda(ps_func)
pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a")
psser = ps.from_pandas(pser)
self.assert_eq(
getattr(psser.expanding(2), f)(), getattr(pser.expanding(2), f)(), almost=True
)
self.assert_eq(
getattr(psser.expanding(2), f)().sum(),
getattr(pser.expanding(2), f)().sum(),
almost=True,
)
self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True)
# NOTE(review): the next assertion is identical to the one above, whereas the
# getattr-based lines it accompanies compared `.sum()` of both results.
# Confirm whether `.sum()` was dropped unintentionally here.
self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True)

# Multiindex
pser = pd.Series(
[1, 2, 3], index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")])
)
psser = ps.from_pandas(pser)
self.assert_eq(getattr(psser.expanding(2), f)(), getattr(pser.expanding(2), f)())
self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)))

pdf = pd.DataFrame(
{"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4)
)
psdf = ps.from_pandas(pdf)
self.assert_eq(getattr(psdf.expanding(2), f)(), getattr(pdf.expanding(2), f)())
self.assert_eq(getattr(psdf.expanding(2), f)().sum(), getattr(pdf.expanding(2), f)().sum())
self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2)))
self.assert_eq(ps_func(psdf.expanding(2)).sum(), pd_func(pdf.expanding(2)).sum())

# Multiindex column
columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
pdf.columns = columns
psdf.columns = columns
self.assert_eq(getattr(psdf.expanding(2), f)(), getattr(pdf.expanding(2), f)())
self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2)))

def test_expanding_error(self):
with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
Expand Down Expand Up @@ -97,16 +97,22 @@ def test_expanding_skew(self):
def test_expanding_kurt(self):
self._test_expanding_func("kurt")

def _test_groupby_expanding_func(self, f):
def _test_groupby_expanding_func(self, ps_func, pd_func=None):
if not pd_func:
pd_func = ps_func
Comment on lines +101 to +102
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I'm a bit confused by this part.

Could you give an example that covers this case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's that it uses the same function if pd_func is omitted.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, just like: self._test_groupby_rolling_func("count").

We pass ps_func and pd_func separately for cases such as: self._test_groupby_expanding_func(lambda x:x.quantile(0.5), lambda x: x.quantile(0.5, "lower")), where pd_func and ps_func differ.

if isinstance(pd_func, str):
pd_func = self.convert_str_to_lambda(pd_func)
if isinstance(ps_func, str):
ps_func = self.convert_str_to_lambda(ps_func)
pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a")
psser = ps.from_pandas(pser)
self.assert_eq(
getattr(psser.groupby(psser).expanding(2), f)().sort_index(),
getattr(pser.groupby(pser).expanding(2), f)().sort_index(),
ps_func(psser.groupby(psser).expanding(2)).sort_index(),
pd_func(pser.groupby(pser).expanding(2)).sort_index(),
)
self.assert_eq(
getattr(psser.groupby(psser).expanding(2), f)().sum(),
getattr(pser.groupby(pser).expanding(2), f)().sum(),
ps_func(psser.groupby(psser).expanding(2)).sum(),
pd_func(pser.groupby(pser).expanding(2)).sum(),
)

# Multiindex
Expand All @@ -117,8 +123,8 @@ def _test_groupby_expanding_func(self, f):
)
psser = ps.from_pandas(pser)
self.assert_eq(
getattr(psser.groupby(psser).expanding(2), f)().sort_index(),
getattr(pser.groupby(pser).expanding(2), f)().sort_index(),
ps_func(psser.groupby(psser).expanding(2)).sort_index(),
pd_func(pser.groupby(pser).expanding(2)).sort_index(),
)

pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]})
Expand All @@ -127,42 +133,42 @@ def _test_groupby_expanding_func(self, f):
# The behavior of GroupBy.expanding is changed from pandas 1.3.
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
self.assert_eq(
getattr(psdf.groupby(psdf.a).expanding(2), f)().sort_index(),
getattr(pdf.groupby(pdf.a).expanding(2), f)().sort_index(),
ps_func(psdf.groupby(psdf.a).expanding(2)).sort_index(),
pd_func(pdf.groupby(pdf.a).expanding(2)).sort_index(),
)
self.assert_eq(
getattr(psdf.groupby(psdf.a).expanding(2), f)().sum(),
getattr(pdf.groupby(pdf.a).expanding(2), f)().sum(),
ps_func(psdf.groupby(psdf.a).expanding(2)).sum(),
pd_func(pdf.groupby(pdf.a).expanding(2)).sum(),
)
self.assert_eq(
getattr(psdf.groupby(psdf.a + 1).expanding(2), f)().sort_index(),
getattr(pdf.groupby(pdf.a + 1).expanding(2), f)().sort_index(),
ps_func(psdf.groupby(psdf.a + 1).expanding(2)).sort_index(),
pd_func(pdf.groupby(pdf.a + 1).expanding(2)).sort_index(),
)
else:
self.assert_eq(
getattr(psdf.groupby(psdf.a).expanding(2), f)().sort_index(),
getattr(pdf.groupby(pdf.a).expanding(2), f)().drop("a", axis=1).sort_index(),
ps_func(psdf.groupby(psdf.a).expanding(2)).sort_index(),
pd_func(pdf.groupby(pdf.a).expanding(2)).drop("a", axis=1).sort_index(),
)
self.assert_eq(
getattr(psdf.groupby(psdf.a).expanding(2), f)().sum(),
getattr(pdf.groupby(pdf.a).expanding(2), f)().sum().drop("a"),
ps_func(psdf.groupby(psdf.a).expanding(2)).sum(),
pd_func(pdf.groupby(pdf.a).expanding(2)).sum().drop("a"),
)
self.assert_eq(
getattr(psdf.groupby(psdf.a + 1).expanding(2), f)().sort_index(),
getattr(pdf.groupby(pdf.a + 1).expanding(2), f)().drop("a", axis=1).sort_index(),
ps_func(psdf.groupby(psdf.a + 1).expanding(2)).sort_index(),
pd_func(pdf.groupby(pdf.a + 1).expanding(2)).drop("a", axis=1).sort_index(),
)

self.assert_eq(
getattr(psdf.b.groupby(psdf.a).expanding(2), f)().sort_index(),
getattr(pdf.b.groupby(pdf.a).expanding(2), f)().sort_index(),
ps_func(psdf.b.groupby(psdf.a).expanding(2)).sort_index(),
pd_func(pdf.b.groupby(pdf.a).expanding(2)).sort_index(),
)
self.assert_eq(
getattr(psdf.groupby(psdf.a)["b"].expanding(2), f)().sort_index(),
getattr(pdf.groupby(pdf.a)["b"].expanding(2), f)().sort_index(),
ps_func(psdf.groupby(psdf.a)["b"].expanding(2)).sort_index(),
pd_func(pdf.groupby(pdf.a)["b"].expanding(2)).sort_index(),
)
self.assert_eq(
getattr(psdf.groupby(psdf.a)[["b"]].expanding(2), f)().sort_index(),
getattr(pdf.groupby(pdf.a)[["b"]].expanding(2), f)().sort_index(),
ps_func(psdf.groupby(psdf.a)[["b"]].expanding(2)).sort_index(),
pd_func(pdf.groupby(pdf.a)[["b"]].expanding(2)).sort_index(),
)

# Multiindex column
Expand All @@ -173,25 +179,23 @@ def _test_groupby_expanding_func(self, f):
# The behavior of GroupBy.expanding is changed from pandas 1.3.
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
self.assert_eq(
getattr(psdf.groupby(("a", "x")).expanding(2), f)().sort_index(),
getattr(pdf.groupby(("a", "x")).expanding(2), f)().sort_index(),
ps_func(psdf.groupby(("a", "x")).expanding(2)).sort_index(),
pd_func(pdf.groupby(("a", "x")).expanding(2)).sort_index(),
)

self.assert_eq(
getattr(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(),
getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(),
ps_func(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2)).sort_index(),
pd_func(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2)).sort_index(),
)
else:
self.assert_eq(
getattr(psdf.groupby(("a", "x")).expanding(2), f)().sort_index(),
getattr(pdf.groupby(("a", "x")).expanding(2), f)()
.drop(("a", "x"), axis=1)
.sort_index(),
ps_func(psdf.groupby(("a", "x")).expanding(2)).sort_index(),
pd_func(pdf.groupby(("a", "x")).expanding(2)).drop(("a", "x"), axis=1).sort_index(),
)

self.assert_eq(
getattr(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(),
getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)()
ps_func(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2)).sort_index(),
pd_func(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2))
.drop([("a", "x"), ("a", "y")], axis=1)
.sort_index(),
)
Expand Down
94 changes: 52 additions & 42 deletions python/pyspark/pandas/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,17 @@ def test_rolling_error(self):
):
Rolling(1, 2)

def _test_rolling_func(self, f):
def _test_rolling_func(self, ps_func, pd_func=None):
if not pd_func:
pd_func = ps_func
if isinstance(pd_func, str):
pd_func = self.convert_str_to_lambda(pd_func)
if isinstance(ps_func, str):
ps_func = self.convert_str_to_lambda(ps_func)
pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a")
psser = ps.from_pandas(pser)
self.assert_eq(getattr(psser.rolling(2), f)(), getattr(pser.rolling(2), f)())
self.assert_eq(getattr(psser.rolling(2), f)().sum(), getattr(pser.rolling(2), f)().sum())
self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2)))
self.assert_eq(ps_func(psser.rolling(2)).sum(), pd_func(pser.rolling(2)).sum())

# Multiindex
pser = pd.Series(
Expand All @@ -49,20 +55,20 @@ def _test_rolling_func(self, f):
name="a",
)
psser = ps.from_pandas(pser)
self.assert_eq(getattr(psser.rolling(2), f)(), getattr(pser.rolling(2), f)())
self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2)))

pdf = pd.DataFrame(
{"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4)
)
psdf = ps.from_pandas(pdf)
self.assert_eq(getattr(psdf.rolling(2), f)(), getattr(pdf.rolling(2), f)())
self.assert_eq(getattr(psdf.rolling(2), f)().sum(), getattr(pdf.rolling(2), f)().sum())
self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2)))
self.assert_eq(ps_func(psdf.rolling(2)).sum(), pd_func(pdf.rolling(2)).sum())

# Multiindex column
columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
pdf.columns = columns
psdf.columns = columns
self.assert_eq(getattr(psdf.rolling(2), f)(), getattr(pdf.rolling(2), f)())
self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2)))

def test_rolling_min(self):
self._test_rolling_func("min")
Expand Down Expand Up @@ -91,16 +97,22 @@ def test_rolling_skew(self):
def test_rolling_kurt(self):
self._test_rolling_func("kurt")

def _test_groupby_rolling_func(self, f):
def _test_groupby_rolling_func(self, ps_func, pd_func=None):
if not pd_func:
pd_func = ps_func
if isinstance(pd_func, str):
pd_func = self.convert_str_to_lambda(pd_func)
if isinstance(ps_func, str):
ps_func = self.convert_str_to_lambda(ps_func)
pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a")
psser = ps.from_pandas(pser)
self.assert_eq(
getattr(psser.groupby(psser).rolling(2), f)().sort_index(),
getattr(pser.groupby(pser).rolling(2), f)().sort_index(),
ps_func(psser.groupby(psser).rolling(2)).sort_index(),
pd_func(pser.groupby(pser).rolling(2)).sort_index(),
)
self.assert_eq(
getattr(psser.groupby(psser).rolling(2), f)().sum(),
getattr(pser.groupby(pser).rolling(2), f)().sum(),
ps_func(psser.groupby(psser).rolling(2)).sum(),
pd_func(pser.groupby(pser).rolling(2)).sum(),
)

# Multiindex
Expand All @@ -111,8 +123,8 @@ def _test_groupby_rolling_func(self, f):
)
psser = ps.from_pandas(pser)
self.assert_eq(
getattr(psser.groupby(psser).rolling(2), f)().sort_index(),
getattr(pser.groupby(pser).rolling(2), f)().sort_index(),
ps_func(psser.groupby(psser).rolling(2)).sort_index(),
pd_func(pser.groupby(pser).rolling(2)).sort_index(),
)

pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]})
Expand All @@ -121,42 +133,42 @@ def _test_groupby_rolling_func(self, f):
# The behavior of GroupBy.rolling is changed from pandas 1.3.
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
self.assert_eq(
getattr(psdf.groupby(psdf.a).rolling(2), f)().sort_index(),
getattr(pdf.groupby(pdf.a).rolling(2), f)().sort_index(),
ps_func(psdf.groupby(psdf.a).rolling(2)).sort_index(),
pd_func(pdf.groupby(pdf.a).rolling(2)).sort_index(),
)
self.assert_eq(
getattr(psdf.groupby(psdf.a).rolling(2), f)().sum(),
getattr(pdf.groupby(pdf.a).rolling(2), f)().sum(),
ps_func(psdf.groupby(psdf.a).rolling(2)).sum(),
pd_func(pdf.groupby(pdf.a).rolling(2)).sum(),
)
self.assert_eq(
getattr(psdf.groupby(psdf.a + 1).rolling(2), f)().sort_index(),
getattr(pdf.groupby(pdf.a + 1).rolling(2), f)().sort_index(),
ps_func(psdf.groupby(psdf.a + 1).rolling(2)).sort_index(),
pd_func(pdf.groupby(pdf.a + 1).rolling(2)).sort_index(),
)
else:
self.assert_eq(
getattr(psdf.groupby(psdf.a).rolling(2), f)().sort_index(),
getattr(pdf.groupby(pdf.a).rolling(2), f)().drop("a", axis=1).sort_index(),
ps_func(psdf.groupby(psdf.a).rolling(2)).sort_index(),
pd_func(pdf.groupby(pdf.a).rolling(2)).drop("a", axis=1).sort_index(),
)
self.assert_eq(
getattr(psdf.groupby(psdf.a).rolling(2), f)().sum(),
getattr(pdf.groupby(pdf.a).rolling(2), f)().sum().drop("a"),
ps_func(psdf.groupby(psdf.a).rolling(2)).sum(),
pd_func(pdf.groupby(pdf.a).rolling(2)).sum().drop("a"),
)
self.assert_eq(
getattr(psdf.groupby(psdf.a + 1).rolling(2), f)().sort_index(),
getattr(pdf.groupby(pdf.a + 1).rolling(2), f)().drop("a", axis=1).sort_index(),
ps_func(psdf.groupby(psdf.a + 1).rolling(2)).sort_index(),
pd_func(pdf.groupby(pdf.a + 1).rolling(2)).drop("a", axis=1).sort_index(),
)

self.assert_eq(
getattr(psdf.b.groupby(psdf.a).rolling(2), f)().sort_index(),
getattr(pdf.b.groupby(pdf.a).rolling(2), f)().sort_index(),
ps_func(psdf.b.groupby(psdf.a).rolling(2)).sort_index(),
pd_func(pdf.b.groupby(pdf.a).rolling(2)).sort_index(),
)
self.assert_eq(
getattr(psdf.groupby(psdf.a)["b"].rolling(2), f)().sort_index(),
getattr(pdf.groupby(pdf.a)["b"].rolling(2), f)().sort_index(),
ps_func(psdf.groupby(psdf.a)["b"].rolling(2)).sort_index(),
pd_func(pdf.groupby(pdf.a)["b"].rolling(2)).sort_index(),
)
self.assert_eq(
getattr(psdf.groupby(psdf.a)[["b"]].rolling(2), f)().sort_index(),
getattr(pdf.groupby(pdf.a)[["b"]].rolling(2), f)().sort_index(),
ps_func(psdf.groupby(psdf.a)[["b"]].rolling(2)).sort_index(),
pd_func(pdf.groupby(pdf.a)[["b"]].rolling(2)).sort_index(),
)

# Multiindex column
Expand All @@ -167,25 +179,23 @@ def _test_groupby_rolling_func(self, f):
# The behavior of GroupBy.rolling is changed from pandas 1.3.
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
self.assert_eq(
getattr(psdf.groupby(("a", "x")).rolling(2), f)().sort_index(),
getattr(pdf.groupby(("a", "x")).rolling(2), f)().sort_index(),
ps_func(psdf.groupby(("a", "x")).rolling(2)).sort_index(),
pd_func(pdf.groupby(("a", "x")).rolling(2)).sort_index(),
)

self.assert_eq(
getattr(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(),
getattr(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(),
ps_func(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(),
pd_func(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(),
)
else:
self.assert_eq(
getattr(psdf.groupby(("a", "x")).rolling(2), f)().sort_index(),
getattr(pdf.groupby(("a", "x")).rolling(2), f)()
.drop(("a", "x"), axis=1)
.sort_index(),
ps_func(psdf.groupby(("a", "x")).rolling(2)).sort_index(),
pd_func(pdf.groupby(("a", "x")).rolling(2)).drop(("a", "x"), axis=1).sort_index(),
)

self.assert_eq(
getattr(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(),
getattr(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)()
ps_func(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(),
pd_func(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2))
.drop([("a", "x"), ("a", "y")], axis=1)
.sort_index(),
)
Expand Down
6 changes: 6 additions & 0 deletions python/pyspark/testing/pandasutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def setUpClass(cls):
super(PandasOnSparkTestCase, cls).setUpClass()
cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True)

def convert_str_to_lambda(self, func):
    """
    Convert the method name `func` into a callable that invokes that method.

    Parameters
    ----------
    func : str
        Name of a zero-argument method, e.g. ``"count"``.

    Returns
    -------
    callable
        A one-argument callable; given an object ``x`` it looks up the
        attribute named `func` on ``x`` and calls it with no arguments.
    """

    def _call_named_method(target):
        # The attribute lookup happens at call time, so the same callable
        # can be applied to objects of different types.
        bound = getattr(target, func)
        return bound()

    return _call_named_method

def assertPandasEqual(self, left, right, check_exact=True):
if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
try:
Expand Down