@@ -3595,7 +3595,7 @@ def pivot_table(self, values=None, index=None, columns=None,
3595
3595
>>> table = df.pivot_table(values='D', index=['A', 'B'],
3596
3596
... columns='C', aggfunc='sum')
3597
3597
>>> table # doctest: +NORMALIZE_WHITESPACE
3598
- large small
3598
+ C large small
3599
3599
A B
3600
3600
foo one 4.0 1
3601
3601
two NaN 6
@@ -3607,7 +3607,7 @@ def pivot_table(self, values=None, index=None, columns=None,
3607
3607
>>> table = df.pivot_table(values='D', index=['A', 'B'],
3608
3608
... columns='C', aggfunc='sum', fill_value=0)
3609
3609
>>> table # doctest: +NORMALIZE_WHITESPACE
3610
- large small
3610
+ C large small
3611
3611
A B
3612
3612
foo one 4 1
3613
3613
two 0 6
@@ -3620,7 +3620,7 @@ def pivot_table(self, values=None, index=None, columns=None,
3620
3620
>>> table = df.pivot_table(values = ['D'], index =['C'],
3621
3621
... columns="A", aggfunc={'D':'mean'})
3622
3622
>>> table # doctest: +NORMALIZE_WHITESPACE
3623
- bar foo
3623
+ A bar foo
3624
3624
C
3625
3625
small 5.5 2.333333
3626
3626
large 5.5 2.000000
@@ -3690,30 +3690,52 @@ def pivot_table(self, values=None, index=None, columns=None,
3690
3690
3691
3691
if index is not None :
3692
3692
if isinstance (values , list ):
3693
+ data_columns = [column for column in sdf .columns if column not in index ]
3694
+
3693
3695
if len (values ) == 2 :
3694
3696
# If we have two values, Spark will return column's name
3695
3697
# in this format: column_values, where column contains
3696
3698
# their values in the DataFrame and values is
3697
3699
# the column list passed to the pivot_table().
3698
3700
# E.g. if column is b and values is ['b','e'],
3699
3701
# then ['2_b', '2_e', '3_b', '3_e'].
3700
- data_columns = [ column for column in sdf . columns if column not in index ]
3702
+
3701
3703
# We sort the columns of Spark DataFrame by values.
3702
3704
data_columns .sort (key = lambda x : x .split ('_' , 1 )[1 ])
3703
3705
sdf = sdf .select (index + data_columns )
3704
- kdf = DataFrame (sdf ).set_index (index )
3705
3706
3706
- if len (values ) == 2 :
3707
+ index_map = [(column , column ) for column in index ]
3708
+ internal = _InternalFrame (sdf = sdf , data_columns = data_columns ,
3709
+ index_map = index_map )
3710
+ kdf = DataFrame (internal )
3711
+
3707
3712
# We build the MultiIndex from the list of columns returned by Spark.
3708
3713
tuples = [(name .split ('_' )[1 ], self .dtypes [columns ].type (name .split ('_' )[0 ]))
3709
3714
for name in kdf ._internal .data_columns ]
3710
3715
kdf .columns = pd .MultiIndex .from_tuples (tuples , names = [None , columns ])
3711
-
3716
+ else :
3717
+ index_map = [(column , column ) for column in index ]
3718
+ internal = _InternalFrame (sdf = sdf , data_columns = data_columns ,
3719
+ index_map = index_map , column_index_names = [columns ])
3720
+ kdf = DataFrame (internal )
3712
3721
return kdf
3713
3722
else :
3714
- return DataFrame (sdf ).set_index (index )
3723
+ data_columns = [column for column in sdf .columns if column not in index ]
3724
+ index_map = [(column , column ) for column in index ]
3725
+ internal = _InternalFrame (sdf = sdf , data_columns = data_columns , index_map = index_map ,
3726
+ column_index_names = [columns ])
3727
+ return DataFrame (internal )
3715
3728
else :
3716
- return DataFrame (sdf .withColumn (columns , F .lit (values ))).set_index (columns )
3729
+ if isinstance (values , list ):
3730
+ index_values = values [- 1 ]
3731
+ else :
3732
+ index_values = values
3733
+ sdf = sdf .withColumn (columns , F .lit (index_values ))
3734
+ data_columns = [column for column in sdf .columns if column not in [columns ]]
3735
+ index_map = [(column , column ) for column in [columns ]]
3736
+ internal = _InternalFrame (sdf = sdf , data_columns = data_columns , index_map = index_map ,
3737
+ column_index_names = [columns ])
3738
+ return DataFrame (internal )
3717
3739
3718
3740
def pivot (self , index = None , columns = None , values = None ):
3719
3741
"""
@@ -3763,14 +3785,14 @@ def pivot(self, index=None, columns=None, values=None):
3763
3785
3764
3786
>>> df.pivot(index='foo', columns='bar', values='baz').sort_index()
3765
3787
... # doctest: +NORMALIZE_WHITESPACE
3766
- A B C
3788
+ bar A B C
3767
3789
foo
3768
3790
one 1 2 3
3769
3791
two 4 5 6
3770
3792
3771
3793
>>> df.pivot(columns='bar', values='baz').sort_index()
3772
3794
... # doctest: +NORMALIZE_WHITESPACE
3773
- A B C
3795
+ bar A B C
3774
3796
0 1.0 NaN NaN
3775
3797
1 NaN 2.0 NaN
3776
3798
2 NaN NaN 3.0
@@ -3795,7 +3817,7 @@ def pivot(self, index=None, columns=None, values=None):
3795
3817
3796
3818
>>> df.pivot(index='foo', columns='bar', values='baz').sort_index()
3797
3819
... # doctest: +NORMALIZE_WHITESPACE
3798
- A B C
3820
+ bar A B C
3799
3821
foo
3800
3822
one 1.0 NaN NaN
3801
3823
two NaN 3.0 4.0
0 commit comments