Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update DataFrame.pivot_table() #635

Merged
merged 7 commits into from
Aug 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 73 additions & 19 deletions databricks/koalas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3546,8 +3546,7 @@ def pivot_table(self, values=None, index=None, columns=None,
Parameters
----------
values : column to aggregate.
They should be either a list of one column or a string. A list of columns
is not supported yet.
They should be either a list less than three or a string.
index : column (string) or list of columns
If an array is passed, it must be the same length as the data.
The list should contain string.
Expand Down Expand Up @@ -3596,7 +3595,7 @@ def pivot_table(self, values=None, index=None, columns=None,
>>> table = df.pivot_table(values='D', index=['A', 'B'],
... columns='C', aggfunc='sum')
>>> table # doctest: +NORMALIZE_WHITESPACE
large small
C large small
A B
foo one 4.0 1
two NaN 6
Expand All @@ -3608,7 +3607,7 @@ def pivot_table(self, values=None, index=None, columns=None,
>>> table = df.pivot_table(values='D', index=['A', 'B'],
... columns='C', aggfunc='sum', fill_value=0)
>>> table # doctest: +NORMALIZE_WHITESPACE
large small
C large small
A B
foo one 4 1
two 0 6
Expand All @@ -3621,10 +3620,22 @@ def pivot_table(self, values=None, index=None, columns=None,
>>> table = df.pivot_table(values = ['D'], index =['C'],
... columns="A", aggfunc={'D':'mean'})
>>> table # doctest: +NORMALIZE_WHITESPACE
bar foo
A bar foo
C
small 5.5 2.333333
large 5.5 2.000000

The next example aggregates on multiple values.

>>> table = df.pivot_table(index=['C'], columns="A", values=['D', 'E'],
... aggfunc={'D': 'mean', 'E': 'sum'})
>>> table # doctest: +NORMALIZE_WHITESPACE
D E
A bar foo bar foo
C
small 5.5 2.333333 17 13
large 5.5 2.000000 15 9

"""
if not isinstance(columns, str):
raise ValueError("columns should be string.")
Expand All @@ -3640,13 +3651,24 @@ def pivot_table(self, values=None, index=None, columns=None,
if isinstance(aggfunc, dict) and index is None:
raise NotImplementedError("pivot_table doesn't support aggfunc"
" as dict and without index.")
if isinstance(values, list) and index is None:
raise NotImplementedError("values can't be a list without index.")

if isinstance(values, list) and len(values) > 1:
raise NotImplementedError('Values as list of columns is not implemented yet.')
if isinstance(values, list) and len(values) > 2:
raise NotImplementedError("values more than two is not supported yet!")

if columns not in self.columns.values:
raise ValueError("Wrong columns {}.".format(columns))

if isinstance(values, list):
if not all(isinstance(self._internal.spark_type_for(col), NumericType)
for col in values):
raise TypeError('values should be a numeric type.')
elif not isinstance(self._internal.spark_type_for(values), NumericType):
raise TypeError('values should be a numeric type.')

if isinstance(aggfunc, str):
agg_cols = [F.expr('{1}(`{0}`) as `{0}`'.format(values, aggfunc))]

elif isinstance(aggfunc, dict):
agg_cols = [F.expr('{1}(`{0}`) as `{0}`'.format(key, value))
for key, value in aggfunc.items()]
Expand All @@ -3667,20 +3689,52 @@ def pivot_table(self, values=None, index=None, columns=None,
sdf = sdf.fillna(fill_value)

if index is not None:
data_columns = [column for column in sdf.columns if column not in index]
index_map = [(column, column) for column in index]
internal = _InternalFrame(sdf=sdf, data_columns=data_columns, index_map=index_map)
return DataFrame(internal)
if isinstance(values, list):
data_columns = [column for column in sdf.columns if column not in index]

if len(values) == 2:
garawalid marked this conversation as resolved.
Show resolved Hide resolved
# If we have two values, Spark will return column's name
# in this format: column_values, where column contains
# their values in the DataFrame and values is
# the column list passed to the pivot_table().
# E.g. if column is b and values is ['b','e'],
# then ['2_b', '2_e', '3_b', '3_e'].

# We sort the columns of Spark DataFrame by values.
data_columns.sort(key=lambda x: x.split('_', 1)[1])
sdf = sdf.select(index + data_columns)

index_map = [(column, column) for column in index]
internal = _InternalFrame(sdf=sdf, data_columns=data_columns,
index_map=index_map)
kdf = DataFrame(internal)

# We build the MultiIndex from the list of columns returned by Spark.
tuples = [(name.split('_')[1], self.dtypes[columns].type(name.split('_')[0]))
for name in kdf._internal.data_columns]
kdf.columns = pd.MultiIndex.from_tuples(tuples, names=[None, columns])
else:
index_map = [(column, column) for column in index]
internal = _InternalFrame(sdf=sdf, data_columns=data_columns,
index_map=index_map, column_index_names=[columns])
kdf = DataFrame(internal)
return kdf
else:
data_columns = [column for column in sdf.columns if column not in index]
index_map = [(column, column) for column in index]
internal = _InternalFrame(sdf=sdf, data_columns=data_columns, index_map=index_map,
column_index_names=[columns])
return DataFrame(internal)
else:
if isinstance(values, list):
index_values = values[-1]
else:
index_values = values

sdf = sdf.withColumn(columns, F.lit(index_values))
data_columns = [column for column in sdf.columns if column not in columns]
index_map = [(column, column) for column in columns]
internal = _InternalFrame(sdf=sdf, data_columns=data_columns, index_map=index_map)
data_columns = [column for column in sdf.columns if column not in [columns]]
index_map = [(column, column) for column in [columns]]
internal = _InternalFrame(sdf=sdf, data_columns=data_columns, index_map=index_map,
column_index_names=[columns])
return DataFrame(internal)

def pivot(self, index=None, columns=None, values=None):
Expand Down Expand Up @@ -3731,14 +3785,14 @@ def pivot(self, index=None, columns=None, values=None):

>>> df.pivot(index='foo', columns='bar', values='baz').sort_index()
... # doctest: +NORMALIZE_WHITESPACE
A B C
bar A B C
foo
one 1 2 3
two 4 5 6

>>> df.pivot(columns='bar', values='baz').sort_index()
... # doctest: +NORMALIZE_WHITESPACE
A B C
bar A B C
0 1.0 NaN NaN
1 NaN 2.0 NaN
2 NaN NaN 3.0
Expand All @@ -3763,7 +3817,7 @@ def pivot(self, index=None, columns=None, values=None):

>>> df.pivot(index='foo', columns='bar', values='baz').sort_index()
... # doctest: +NORMALIZE_WHITESPACE
A B C
bar A B C
foo
one 1.0 NaN NaN
two NaN 3.0 4.0
Expand Down
41 changes: 36 additions & 5 deletions databricks/koalas/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,9 +1102,10 @@ def test_pivot_table(self):
# Todo: self.assert_eq(kdf.pivot_table(index=['c'], columns="a", values="b"),
# pdf.pivot_table(index=['c'], columns=["a"], values="b"))

# Todo: self.assert_eq(kdf.pivot_table(index=['c'], columns="a", values=['b', 'e'],
# aggfunc={'b': 'mean', 'e': 'sum'}), pdf.pivot_table(index=['c'], columns=["a"],
# values=['b', 'e'], aggfunc={'b': 'mean', 'e': 'sum'}))
self.assert_eq(kdf.pivot_table(index=['c'], columns="a", values=['b', 'e'],
aggfunc={'b': 'mean', 'e': 'sum'}).sort_index(),
pdf.pivot_table(index=['c'], columns=["a"],
values=['b', 'e'], aggfunc={'b': 'mean', 'e': 'sum'}))

# Todo: self.assert_eq(kdf.pivot_table(index=['e', 'c'], columns="a", values="b"),
# pdf.pivot_table(index=['e', 'c'], columns="a", values="b"))
Expand Down Expand Up @@ -1153,11 +1154,41 @@ def test_pivot_table_errors(self):
kdf.pivot_table(index=['e', 'c'], columns="a", values='b',
aggfunc={'b': 'mean', 'e': 'sum'})

msg = 'Values as list of columns is not implemented yet.'
msg = "values can't be a list without index."
with self.assertRaisesRegex(NotImplementedError, msg):
kdf.pivot_table(index=['c'], columns="a", values=['b', 'e'],
kdf.pivot_table(columns="a", values=['b', 'e'])

msg = "values more than two is not supported yet!"
with self.assertRaisesRegex(NotImplementedError, msg):
kdf.pivot_table(index=['e'], columns="a", values=['b', 'e', 'c'],
aggfunc={'b': 'mean', 'e': 'sum', 'c': 'sum'})

msg = "Wrong columns A."
with self.assertRaisesRegex(ValueError, msg):
kdf.pivot_table(index=['c'], columns="A", values=['b', 'e'],
aggfunc={'b': 'mean', 'e': 'sum'})

kdf = ks.DataFrame({"A": ["foo", "foo", "foo", "foo", "foo",
"bar", "bar", "bar", "bar"],
"B": ["one", "one", "one", "two", "two",
"one", "one", "two", "two"],
"C": ["small", "large", "large", "small",
"small", "large", "small", "small",
"large"],
"D": [1, 2, 2, 3, 3, 4, 5, 6, 7],
"E": [2, 4, 5, 5, 6, 6, 8, 9, 9]},
columns=['A', 'B', 'C', 'D', 'E'])

msg = "values should be a numeric type."
with self.assertRaisesRegex(TypeError, msg):
kdf.pivot_table(index=['C'], columns="A", values=['B', 'E'],
aggfunc={'B': 'mean', 'E': 'sum'})

msg = "values should be a numeric type."
with self.assertRaisesRegex(TypeError, msg):
kdf.pivot_table(index=['C'], columns="A", values='B',
aggfunc={'B': 'mean'})

def test_transpose(self):
pdf1 = pd.DataFrame(
data={'col1': [1, 2], 'col2': [3, 4]},
Expand Down