Skip to content

Commit ec7dcae

Browse files
committed
Support column names (#636)
1 parent 7372652 commit ec7dcae

File tree

1 file changed

+34
-12
lines changed

1 file changed

+34
-12
lines changed

databricks/koalas/frame.py

+34-12
Original file line numberDiff line numberDiff line change
@@ -3595,7 +3595,7 @@ def pivot_table(self, values=None, index=None, columns=None,
35953595
>>> table = df.pivot_table(values='D', index=['A', 'B'],
35963596
... columns='C', aggfunc='sum')
35973597
>>> table # doctest: +NORMALIZE_WHITESPACE
3598-
large small
3598+
C large small
35993599
A B
36003600
foo one 4.0 1
36013601
two NaN 6
@@ -3607,7 +3607,7 @@ def pivot_table(self, values=None, index=None, columns=None,
36073607
>>> table = df.pivot_table(values='D', index=['A', 'B'],
36083608
... columns='C', aggfunc='sum', fill_value=0)
36093609
>>> table # doctest: +NORMALIZE_WHITESPACE
3610-
large small
3610+
C large small
36113611
A B
36123612
foo one 4 1
36133613
two 0 6
@@ -3620,7 +3620,7 @@ def pivot_table(self, values=None, index=None, columns=None,
36203620
>>> table = df.pivot_table(values = ['D'], index =['C'],
36213621
... columns="A", aggfunc={'D':'mean'})
36223622
>>> table # doctest: +NORMALIZE_WHITESPACE
3623-
bar foo
3623+
A bar foo
36243624
C
36253625
small 5.5 2.333333
36263626
large 5.5 2.000000
@@ -3690,30 +3690,52 @@ def pivot_table(self, values=None, index=None, columns=None,
36903690

36913691
if index is not None:
36923692
if isinstance(values, list):
3693+
data_columns = [column for column in sdf.columns if column not in index]
3694+
36933695
if len(values) == 2:
36943696
# If we have two values, Spark will return column's name
36953697
# in this format: column_values, where column contains
36963698
# their values in the DataFrame and values is
36973699
# the column list passed to the pivot_table().
36983700
# E.g. if column is b and values is ['b','e'],
36993701
# then ['2_b', '2_e', '3_b', '3_e'].
3700-
data_columns = [column for column in sdf.columns if column not in index]
3702+
37013703
# We sort the columns of Spark DataFrame by values.
37023704
data_columns.sort(key=lambda x: x.split('_', 1)[1])
37033705
sdf = sdf.select(index + data_columns)
3704-
kdf = DataFrame(sdf).set_index(index)
37053706

3706-
if len(values) == 2:
3707+
index_map = [(column, column) for column in index]
3708+
internal = _InternalFrame(sdf=sdf, data_columns=data_columns,
3709+
index_map=index_map)
3710+
kdf = DataFrame(internal)
3711+
37073712
# We build the MultiIndex from the list of columns returned by Spark.
37083713
tuples = [(name.split('_')[1], self.dtypes[columns].type(name.split('_')[0]))
37093714
for name in kdf._internal.data_columns]
37103715
kdf.columns = pd.MultiIndex.from_tuples(tuples, names=[None, columns])
3711-
3716+
else:
3717+
index_map = [(column, column) for column in index]
3718+
internal = _InternalFrame(sdf=sdf, data_columns=data_columns,
3719+
index_map=index_map, column_index_names=[columns])
3720+
kdf = DataFrame(internal)
37123721
return kdf
37133722
else:
3714-
return DataFrame(sdf).set_index(index)
3723+
data_columns = [column for column in sdf.columns if column not in index]
3724+
index_map = [(column, column) for column in index]
3725+
internal = _InternalFrame(sdf=sdf, data_columns=data_columns, index_map=index_map,
3726+
column_index_names=[columns])
3727+
return DataFrame(internal)
37153728
else:
3716-
return DataFrame(sdf.withColumn(columns, F.lit(values))).set_index(columns)
3729+
if isinstance(values, list):
3730+
index_values = values[-1]
3731+
else:
3732+
index_values = values
3733+
sdf = sdf.withColumn(columns, F.lit(index_values))
3734+
data_columns = [column for column in sdf.columns if column not in [columns]]
3735+
index_map = [(column, column) for column in [columns]]
3736+
internal = _InternalFrame(sdf=sdf, data_columns=data_columns, index_map=index_map,
3737+
column_index_names=[columns])
3738+
return DataFrame(internal)
37173739

37183740
def pivot(self, index=None, columns=None, values=None):
37193741
"""
@@ -3763,14 +3785,14 @@ def pivot(self, index=None, columns=None, values=None):
37633785
37643786
>>> df.pivot(index='foo', columns='bar', values='baz').sort_index()
37653787
... # doctest: +NORMALIZE_WHITESPACE
3766-
A B C
3788+
bar A B C
37673789
foo
37683790
one 1 2 3
37693791
two 4 5 6
37703792
37713793
>>> df.pivot(columns='bar', values='baz').sort_index()
37723794
... # doctest: +NORMALIZE_WHITESPACE
3773-
A B C
3795+
bar A B C
37743796
0 1.0 NaN NaN
37753797
1 NaN 2.0 NaN
37763798
2 NaN NaN 3.0
@@ -3795,7 +3817,7 @@ def pivot(self, index=None, columns=None, values=None):
37953817
37963818
>>> df.pivot(index='foo', columns='bar', values='baz').sort_index()
37973819
... # doctest: +NORMALIZE_WHITESPACE
3798-
A B C
3820+
bar A B C
37993821
foo
38003822
one 1.0 NaN NaN
38013823
two NaN 3.0 4.0

Comments: 0 — this commit has no comments.