Skip to content

Commit b7fc773

Browse files
garawalidueshin
authored andcommitted
Update DataFrame.pivot_table() (#635)
Resolves #511, #636. In the test, the `kdf` is converted to Pandas DataFrame in order to use `sort_index()`. I'll update the test once #634 resolved.
1 parent 446d393 commit b7fc773

File tree

2 files changed

+109
-24
lines changed

2 files changed

+109
-24
lines changed

databricks/koalas/frame.py

+73-19
Original file line numberDiff line numberDiff line change
@@ -3551,8 +3551,7 @@ def pivot_table(self, values=None, index=None, columns=None,
35513551
Parameters
35523552
----------
35533553
values : column to aggregate.
3554-
They should be either a list of one column or a string. A list of columns
3555-
is not supported yet.
3554+
They should be either a list less than three or a string.
35563555
index : column (string) or list of columns
35573556
If an array is passed, it must be the same length as the data.
35583557
The list should contain string.
@@ -3601,7 +3600,7 @@ def pivot_table(self, values=None, index=None, columns=None,
36013600
>>> table = df.pivot_table(values='D', index=['A', 'B'],
36023601
... columns='C', aggfunc='sum')
36033602
>>> table # doctest: +NORMALIZE_WHITESPACE
3604-
large small
3603+
C large small
36053604
A B
36063605
foo one 4.0 1
36073606
two NaN 6
@@ -3613,7 +3612,7 @@ def pivot_table(self, values=None, index=None, columns=None,
36133612
>>> table = df.pivot_table(values='D', index=['A', 'B'],
36143613
... columns='C', aggfunc='sum', fill_value=0)
36153614
>>> table # doctest: +NORMALIZE_WHITESPACE
3616-
large small
3615+
C large small
36173616
A B
36183617
foo one 4 1
36193618
two 0 6
@@ -3626,10 +3625,22 @@ def pivot_table(self, values=None, index=None, columns=None,
36263625
>>> table = df.pivot_table(values = ['D'], index =['C'],
36273626
... columns="A", aggfunc={'D':'mean'})
36283627
>>> table # doctest: +NORMALIZE_WHITESPACE
3629-
bar foo
3628+
A bar foo
36303629
C
36313630
small 5.5 2.333333
36323631
large 5.5 2.000000
3632+
3633+
The next example aggregates on multiple values.
3634+
3635+
>>> table = df.pivot_table(index=['C'], columns="A", values=['D', 'E'],
3636+
... aggfunc={'D': 'mean', 'E': 'sum'})
3637+
>>> table # doctest: +NORMALIZE_WHITESPACE
3638+
D E
3639+
A bar foo bar foo
3640+
C
3641+
small 5.5 2.333333 17 13
3642+
large 5.5 2.000000 15 9
3643+
36333644
"""
36343645
if not isinstance(columns, str):
36353646
raise ValueError("columns should be string.")
@@ -3645,13 +3656,24 @@ def pivot_table(self, values=None, index=None, columns=None,
36453656
if isinstance(aggfunc, dict) and index is None:
36463657
raise NotImplementedError("pivot_table doesn't support aggfunc"
36473658
" as dict and without index.")
3659+
if isinstance(values, list) and index is None:
3660+
raise NotImplementedError("values can't be a list without index.")
36483661

3649-
if isinstance(values, list) and len(values) > 1:
3650-
raise NotImplementedError('Values as list of columns is not implemented yet.')
3662+
if isinstance(values, list) and len(values) > 2:
3663+
raise NotImplementedError("values more than two is not supported yet!")
3664+
3665+
if columns not in self.columns.values:
3666+
raise ValueError("Wrong columns {}.".format(columns))
3667+
3668+
if isinstance(values, list):
3669+
if not all(isinstance(self._internal.spark_type_for(col), NumericType)
3670+
for col in values):
3671+
raise TypeError('values should be a numeric type.')
3672+
elif not isinstance(self._internal.spark_type_for(values), NumericType):
3673+
raise TypeError('values should be a numeric type.')
36513674

36523675
if isinstance(aggfunc, str):
36533676
agg_cols = [F.expr('{1}(`{0}`) as `{0}`'.format(values, aggfunc))]
3654-
36553677
elif isinstance(aggfunc, dict):
36563678
agg_cols = [F.expr('{1}(`{0}`) as `{0}`'.format(key, value))
36573679
for key, value in aggfunc.items()]
@@ -3672,20 +3694,52 @@ def pivot_table(self, values=None, index=None, columns=None,
36723694
sdf = sdf.fillna(fill_value)
36733695

36743696
if index is not None:
3675-
data_columns = [column for column in sdf.columns if column not in index]
3676-
index_map = [(column, column) for column in index]
3677-
internal = _InternalFrame(sdf=sdf, data_columns=data_columns, index_map=index_map)
3678-
return DataFrame(internal)
3697+
if isinstance(values, list):
3698+
data_columns = [column for column in sdf.columns if column not in index]
3699+
3700+
if len(values) == 2:
3701+
# If we have two values, Spark will return column's name
3702+
# in this format: column_values, where column contains
3703+
# their values in the DataFrame and values is
3704+
# the column list passed to the pivot_table().
3705+
# E.g. if column is b and values is ['b','e'],
3706+
# then ['2_b', '2_e', '3_b', '3_e'].
3707+
3708+
# We sort the columns of Spark DataFrame by values.
3709+
data_columns.sort(key=lambda x: x.split('_', 1)[1])
3710+
sdf = sdf.select(index + data_columns)
3711+
3712+
index_map = [(column, column) for column in index]
3713+
internal = _InternalFrame(sdf=sdf, data_columns=data_columns,
3714+
index_map=index_map)
3715+
kdf = DataFrame(internal)
3716+
3717+
# We build the MultiIndex from the list of columns returned by Spark.
3718+
tuples = [(name.split('_')[1], self.dtypes[columns].type(name.split('_')[0]))
3719+
for name in kdf._internal.data_columns]
3720+
kdf.columns = pd.MultiIndex.from_tuples(tuples, names=[None, columns])
3721+
else:
3722+
index_map = [(column, column) for column in index]
3723+
internal = _InternalFrame(sdf=sdf, data_columns=data_columns,
3724+
index_map=index_map, column_index_names=[columns])
3725+
kdf = DataFrame(internal)
3726+
return kdf
3727+
else:
3728+
data_columns = [column for column in sdf.columns if column not in index]
3729+
index_map = [(column, column) for column in index]
3730+
internal = _InternalFrame(sdf=sdf, data_columns=data_columns, index_map=index_map,
3731+
column_index_names=[columns])
3732+
return DataFrame(internal)
36793733
else:
36803734
if isinstance(values, list):
36813735
index_values = values[-1]
36823736
else:
36833737
index_values = values
3684-
36853738
sdf = sdf.withColumn(columns, F.lit(index_values))
3686-
data_columns = [column for column in sdf.columns if column not in columns]
3687-
index_map = [(column, column) for column in columns]
3688-
internal = _InternalFrame(sdf=sdf, data_columns=data_columns, index_map=index_map)
3739+
data_columns = [column for column in sdf.columns if column not in [columns]]
3740+
index_map = [(column, column) for column in [columns]]
3741+
internal = _InternalFrame(sdf=sdf, data_columns=data_columns, index_map=index_map,
3742+
column_index_names=[columns])
36893743
return DataFrame(internal)
36903744

36913745
def pivot(self, index=None, columns=None, values=None):
@@ -3736,14 +3790,14 @@ def pivot(self, index=None, columns=None, values=None):
37363790
37373791
>>> df.pivot(index='foo', columns='bar', values='baz').sort_index()
37383792
... # doctest: +NORMALIZE_WHITESPACE
3739-
A B C
3793+
bar A B C
37403794
foo
37413795
one 1 2 3
37423796
two 4 5 6
37433797
37443798
>>> df.pivot(columns='bar', values='baz').sort_index()
37453799
... # doctest: +NORMALIZE_WHITESPACE
3746-
A B C
3800+
bar A B C
37473801
0 1.0 NaN NaN
37483802
1 NaN 2.0 NaN
37493803
2 NaN NaN 3.0
@@ -3768,7 +3822,7 @@ def pivot(self, index=None, columns=None, values=None):
37683822
37693823
>>> df.pivot(index='foo', columns='bar', values='baz').sort_index()
37703824
... # doctest: +NORMALIZE_WHITESPACE
3771-
A B C
3825+
bar A B C
37723826
foo
37733827
one 1.0 NaN NaN
37743828
two NaN 3.0 4.0

databricks/koalas/tests/test_dataframe.py

+36-5
Original file line numberDiff line numberDiff line change
@@ -1102,9 +1102,10 @@ def test_pivot_table(self):
11021102
# Todo: self.assert_eq(kdf.pivot_table(index=['c'], columns="a", values="b"),
11031103
# pdf.pivot_table(index=['c'], columns=["a"], values="b"))
11041104

1105-
# Todo: self.assert_eq(kdf.pivot_table(index=['c'], columns="a", values=['b', 'e'],
1106-
# aggfunc={'b': 'mean', 'e': 'sum'}), pdf.pivot_table(index=['c'], columns=["a"],
1107-
# values=['b', 'e'], aggfunc={'b': 'mean', 'e': 'sum'}))
1105+
self.assert_eq(kdf.pivot_table(index=['c'], columns="a", values=['b', 'e'],
1106+
aggfunc={'b': 'mean', 'e': 'sum'}).sort_index(),
1107+
pdf.pivot_table(index=['c'], columns=["a"],
1108+
values=['b', 'e'], aggfunc={'b': 'mean', 'e': 'sum'}))
11081109

11091110
# Todo: self.assert_eq(kdf.pivot_table(index=['e', 'c'], columns="a", values="b"),
11101111
# pdf.pivot_table(index=['e', 'c'], columns="a", values="b"))
@@ -1153,11 +1154,41 @@ def test_pivot_table_errors(self):
11531154
kdf.pivot_table(index=['e', 'c'], columns="a", values='b',
11541155
aggfunc={'b': 'mean', 'e': 'sum'})
11551156

1156-
msg = 'Values as list of columns is not implemented yet.'
1157+
msg = "values can't be a list without index."
11571158
with self.assertRaisesRegex(NotImplementedError, msg):
1158-
kdf.pivot_table(index=['c'], columns="a", values=['b', 'e'],
1159+
kdf.pivot_table(columns="a", values=['b', 'e'])
1160+
1161+
msg = "values more than two is not supported yet!"
1162+
with self.assertRaisesRegex(NotImplementedError, msg):
1163+
kdf.pivot_table(index=['e'], columns="a", values=['b', 'e', 'c'],
1164+
aggfunc={'b': 'mean', 'e': 'sum', 'c': 'sum'})
1165+
1166+
msg = "Wrong columns A."
1167+
with self.assertRaisesRegex(ValueError, msg):
1168+
kdf.pivot_table(index=['c'], columns="A", values=['b', 'e'],
11591169
aggfunc={'b': 'mean', 'e': 'sum'})
11601170

1171+
kdf = ks.DataFrame({"A": ["foo", "foo", "foo", "foo", "foo",
1172+
"bar", "bar", "bar", "bar"],
1173+
"B": ["one", "one", "one", "two", "two",
1174+
"one", "one", "two", "two"],
1175+
"C": ["small", "large", "large", "small",
1176+
"small", "large", "small", "small",
1177+
"large"],
1178+
"D": [1, 2, 2, 3, 3, 4, 5, 6, 7],
1179+
"E": [2, 4, 5, 5, 6, 6, 8, 9, 9]},
1180+
columns=['A', 'B', 'C', 'D', 'E'])
1181+
1182+
msg = "values should be a numeric type."
1183+
with self.assertRaisesRegex(TypeError, msg):
1184+
kdf.pivot_table(index=['C'], columns="A", values=['B', 'E'],
1185+
aggfunc={'B': 'mean', 'E': 'sum'})
1186+
1187+
msg = "values should be a numeric type."
1188+
with self.assertRaisesRegex(TypeError, msg):
1189+
kdf.pivot_table(index=['C'], columns="A", values='B',
1190+
aggfunc={'B': 'mean'})
1191+
11611192
def test_transpose(self):
11621193
pdf1 = pd.DataFrame(
11631194
data={'col1': [1, 2], 'col2': [3, 4]},

0 commit comments

Comments
 (0)