diff --git a/databricks/koalas/generic.py b/databricks/koalas/generic.py index 87f2463209..dc02cb81ba 100644 --- a/databricks/koalas/generic.py +++ b/databricks/koalas/generic.py @@ -1262,16 +1262,31 @@ def groupby(self, by, as_index: bool = True): from databricks.koalas.groupby import DataFrameGroupBy, SeriesGroupBy df_or_s = self - if isinstance(by, str): + if isinstance(by, DataFrame): + raise ValueError("Grouper for '{}' not 1-dimensional".format(type(by))) + elif isinstance(by, str): + if isinstance(df_or_s, Series): + raise KeyError(by) by = [(by,)] elif isinstance(by, tuple): + if isinstance(df_or_s, Series): + for key in by: + if isinstance(key, str): + raise KeyError(key) + for key in by: + if isinstance(key, DataFrame): + raise ValueError("Grouper for '{}' not 1-dimensional".format(type(key))) by = [by] elif isinstance(by, Series): by = [by] elif isinstance(by, Iterable): + if isinstance(df_or_s, Series): + for key in by: + if isinstance(key, str): + raise KeyError(key) by = [key if isinstance(key, (tuple, Series)) else (key,) for key in by] else: - raise ValueError('Not a valid index: TODO') + raise ValueError("Grouper for '{}' not 1-dimensional".format(type(by))) if not len(by): raise ValueError('No group keys passed!') if isinstance(df_or_s, DataFrame): diff --git a/databricks/koalas/tests/test_groupby.py b/databricks/koalas/tests/test_groupby.py index aebf27fa1f..36f094a8d7 100644 --- a/databricks/koalas/tests/test_groupby.py +++ b/databricks/koalas/tests/test_groupby.py @@ -68,6 +68,16 @@ def test_groupby(self): self.assertRaises(TypeError, lambda: kdf.a.groupby(kdf.b, as_index=False)) + # we can't use column name/names as a parameter `by` for `SeriesGroupBy`. + self.assertRaises(KeyError, lambda: kdf.a.groupby(by='a')) + self.assertRaises(KeyError, lambda: kdf.a.groupby(by=['a', 'b'])) + self.assertRaises(KeyError, lambda: kdf.a.groupby(by=('a', 'b'))) + + # we can't use DataFrame as a parameter `by` for `DataFrameGroupBy`/`SeriesGroupBy`. + self.assertRaises(ValueError, lambda: kdf.groupby(kdf)) + self.assertRaises(ValueError, lambda: kdf.a.groupby(kdf)) + self.assertRaises(ValueError, lambda: kdf.a.groupby((kdf,))) + def test_groupby_multiindex_columns(self): pdf = pd.DataFrame({('x', 'a'): [1, 2, 6, 4, 4, 6, 4, 3, 7], ('x', 'b'): [4, 2, 7, 3, 3, 1, 1, 1, 2], @@ -838,7 +848,7 @@ def test_missing(self): with self.assertRaisesRegex( PandasNotImplementedError, "method.*GroupBy.*{}.*not implemented( yet\\.|\\. .+)".format(name)): - getattr(kdf.a.groupby('a'), name)() + getattr(kdf.a.groupby(kdf.a), name)() deprecated_functions = [name for (name, type_) in missing_functions if type_.__name__ == 'deprecated_function'] @@ -846,7 +856,7 @@ def test_missing(self): with self.assertRaisesRegex(PandasNotImplementedError, "method.*GroupBy.*{}.*is deprecated" .format(name)): - getattr(kdf.a.groupby('a'), name)() + getattr(kdf.a.groupby(kdf.a), name)() # DataFrameGroupBy properties missing_properties = inspect.getmembers(_MissingPandasLikeDataFrameGroupBy, @@ -875,14 +885,14 @@ def test_missing(self): with self.assertRaisesRegex( PandasNotImplementedError, "property.*GroupBy.*{}.*not implemented( yet\\.|\\. .+)".format(name)): - getattr(kdf.a.groupby('a'), name) + getattr(kdf.a.groupby(kdf.a), name) deprecated_properties = [name for (name, type_) in missing_properties if type_.fget.__name__ == 'deprecated_property'] for name in deprecated_properties: with self.assertRaisesRegex(PandasNotImplementedError, "property.*GroupBy.*{}.*is deprecated" .format(name)): - getattr(kdf.a.groupby('a'), name) + getattr(kdf.a.groupby(kdf.a), name) @staticmethod def test_is_multi_agg_with_relabel():