@@ -3551,8 +3551,7 @@ def pivot_table(self, values=None, index=None, columns=None,
3551
3551
Parameters
3552
3552
----------
3553
3553
values : column to aggregate.
3554
- They should be either a list of one column or a string. A list of columns
3555
- is not supported yet.
3554
+ They should be either a string or a list of at most two column names.
3556
3555
index : column (string) or list of columns
3557
3556
If an array is passed, it must be the same length as the data.
3558
3557
The list should contain string.
@@ -3601,7 +3600,7 @@ def pivot_table(self, values=None, index=None, columns=None,
3601
3600
>>> table = df.pivot_table(values='D', index=['A', 'B'],
3602
3601
... columns='C', aggfunc='sum')
3603
3602
>>> table # doctest: +NORMALIZE_WHITESPACE
3604
- large small
3603
+ C large small
3605
3604
A B
3606
3605
foo one 4.0 1
3607
3606
two NaN 6
@@ -3613,7 +3612,7 @@ def pivot_table(self, values=None, index=None, columns=None,
3613
3612
>>> table = df.pivot_table(values='D', index=['A', 'B'],
3614
3613
... columns='C', aggfunc='sum', fill_value=0)
3615
3614
>>> table # doctest: +NORMALIZE_WHITESPACE
3616
- large small
3615
+ C large small
3617
3616
A B
3618
3617
foo one 4 1
3619
3618
two 0 6
@@ -3626,10 +3625,22 @@ def pivot_table(self, values=None, index=None, columns=None,
3626
3625
>>> table = df.pivot_table(values = ['D'], index =['C'],
3627
3626
... columns="A", aggfunc={'D':'mean'})
3628
3627
>>> table # doctest: +NORMALIZE_WHITESPACE
3629
- bar foo
3628
+ A bar foo
3630
3629
C
3631
3630
small 5.5 2.333333
3632
3631
large 5.5 2.000000
3632
+
3633
+ The next example aggregates on multiple values.
3634
+
3635
+ >>> table = df.pivot_table(index=['C'], columns="A", values=['D', 'E'],
3636
+ ... aggfunc={'D': 'mean', 'E': 'sum'})
3637
+ >>> table # doctest: +NORMALIZE_WHITESPACE
3638
+ D E
3639
+ A bar foo bar foo
3640
+ C
3641
+ small 5.5 2.333333 17 13
3642
+ large 5.5 2.000000 15 9
3643
+
3633
3644
"""
3634
3645
if not isinstance (columns , str ):
3635
3646
raise ValueError ("columns should be string." )
@@ -3645,13 +3656,24 @@ def pivot_table(self, values=None, index=None, columns=None,
3645
3656
if isinstance (aggfunc , dict ) and index is None :
3646
3657
raise NotImplementedError ("pivot_table doesn't support aggfunc"
3647
3658
" as dict and without index." )
3659
+ if isinstance (values , list ) and index is None :
3660
+ raise NotImplementedError ("values can't be a list without index." )
3648
3661
3649
- if isinstance (values , list ) and len (values ) > 1 :
3650
- raise NotImplementedError ('Values as list of columns is not implemented yet.' )
3662
+ if isinstance (values , list ) and len (values ) > 2 :
3663
+ raise NotImplementedError ("values more than two is not supported yet!" )
3664
+
3665
+ if columns not in self .columns .values :
3666
+ raise ValueError ("Wrong columns {}." .format (columns ))
3667
+
3668
+ if isinstance (values , list ):
3669
+ if not all (isinstance (self ._internal .spark_type_for (col ), NumericType )
3670
+ for col in values ):
3671
+ raise TypeError ('values should be a numeric type.' )
3672
+ elif not isinstance (self ._internal .spark_type_for (values ), NumericType ):
3673
+ raise TypeError ('values should be a numeric type.' )
3651
3674
3652
3675
if isinstance (aggfunc , str ):
3653
3676
agg_cols = [F .expr ('{1}(`{0}`) as `{0}`' .format (values , aggfunc ))]
3654
-
3655
3677
elif isinstance (aggfunc , dict ):
3656
3678
agg_cols = [F .expr ('{1}(`{0}`) as `{0}`' .format (key , value ))
3657
3679
for key , value in aggfunc .items ()]
@@ -3672,20 +3694,52 @@ def pivot_table(self, values=None, index=None, columns=None,
3672
3694
sdf = sdf .fillna (fill_value )
3673
3695
3674
3696
if index is not None :
3675
- data_columns = [column for column in sdf .columns if column not in index ]
3676
- index_map = [(column , column ) for column in index ]
3677
- internal = _InternalFrame (sdf = sdf , data_columns = data_columns , index_map = index_map )
3678
- return DataFrame (internal )
3697
+ if isinstance (values , list ):
3698
+ data_columns = [column for column in sdf .columns if column not in index ]
3699
+
3700
+ if len (values ) == 2 :
3701
+ # If we have two values, Spark names the pivoted columns
3702
+ # in the format <pivot value>_<value column>, where the
3703
+ # pivot values come from the column given by `columns` and
3704
+ # the value columns are those passed in `values`.
3705
+ # E.g. if the pivot column holds 2 and 3 and values is
3706
+ # ['b', 'e'], Spark returns ['2_b', '2_e', '3_b', '3_e'].
3707
+
3708
+ # We sort the columns of Spark DataFrame by values.
3709
+ data_columns .sort (key = lambda x : x .split ('_' , 1 )[1 ])
3710
+ sdf = sdf .select (index + data_columns )
3711
+
3712
+ index_map = [(column , column ) for column in index ]
3713
+ internal = _InternalFrame (sdf = sdf , data_columns = data_columns ,
3714
+ index_map = index_map )
3715
+ kdf = DataFrame (internal )
3716
+
3717
+ # We build the MultiIndex from the list of columns returned by Spark.
3718
+ tuples = [(name .split ('_' )[1 ], self .dtypes [columns ].type (name .split ('_' )[0 ]))
3719
+ for name in kdf ._internal .data_columns ]
3720
+ kdf .columns = pd .MultiIndex .from_tuples (tuples , names = [None , columns ])
3721
+ else :
3722
+ index_map = [(column , column ) for column in index ]
3723
+ internal = _InternalFrame (sdf = sdf , data_columns = data_columns ,
3724
+ index_map = index_map , column_index_names = [columns ])
3725
+ kdf = DataFrame (internal )
3726
+ return kdf
3727
+ else :
3728
+ data_columns = [column for column in sdf .columns if column not in index ]
3729
+ index_map = [(column , column ) for column in index ]
3730
+ internal = _InternalFrame (sdf = sdf , data_columns = data_columns , index_map = index_map ,
3731
+ column_index_names = [columns ])
3732
+ return DataFrame (internal )
3679
3733
else :
3680
3734
if isinstance (values , list ):
3681
3735
index_values = values [- 1 ]
3682
3736
else :
3683
3737
index_values = values
3684
-
3685
3738
sdf = sdf .withColumn (columns , F .lit (index_values ))
3686
- data_columns = [column for column in sdf .columns if column not in columns ]
3687
- index_map = [(column , column ) for column in columns ]
3688
- internal = _InternalFrame (sdf = sdf , data_columns = data_columns , index_map = index_map )
3739
+ data_columns = [column for column in sdf .columns if column not in [columns ]]
3740
+ index_map = [(column , column ) for column in [columns ]]
3741
+ internal = _InternalFrame (sdf = sdf , data_columns = data_columns , index_map = index_map ,
3742
+ column_index_names = [columns ])
3689
3743
return DataFrame (internal )
3690
3744
3691
3745
def pivot (self , index = None , columns = None , values = None ):
@@ -3736,14 +3790,14 @@ def pivot(self, index=None, columns=None, values=None):
3736
3790
3737
3791
>>> df.pivot(index='foo', columns='bar', values='baz').sort_index()
3738
3792
... # doctest: +NORMALIZE_WHITESPACE
3739
- A B C
3793
+ bar A B C
3740
3794
foo
3741
3795
one 1 2 3
3742
3796
two 4 5 6
3743
3797
3744
3798
>>> df.pivot(columns='bar', values='baz').sort_index()
3745
3799
... # doctest: +NORMALIZE_WHITESPACE
3746
- A B C
3800
+ bar A B C
3747
3801
0 1.0 NaN NaN
3748
3802
1 NaN 2.0 NaN
3749
3803
2 NaN NaN 3.0
@@ -3768,7 +3822,7 @@ def pivot(self, index=None, columns=None, values=None):
3768
3822
3769
3823
>>> df.pivot(index='foo', columns='bar', values='baz').sort_index()
3770
3824
... # doctest: +NORMALIZE_WHITESPACE
3771
- A B C
3825
+ bar A B C
3772
3826
foo
3773
3827
one 1.0 NaN NaN
3774
3828
two NaN 3.0 4.0
0 commit comments