Fix DataFrame.drop() to remove fields from Spark DataFrame also. (#794)
When we drop columns from a DataFrame with `DataFrame.drop()`,

the resulting DataFrame shows the columns dropped as expected:

```python
>>> df
     name   class  max_speed
0  falcon    bird      389.0
1  parrot    bird       24.0
2    lion  mammal       80.5
3  monkey  mammal        NaN
>>>
>>> df = df.drop('name')
>>> df
    class  max_speed
0    bird      389.0
1    bird       24.0
2  mammal       80.5
3  mammal        NaN
```

However, when we inspect the internal Spark DataFrame afterwards,

it still shows the original data, including the dropped columns:

```
>>> df._sdf.show()
+-----------------+------+------+---------+
|__index_level_0__|  name| class|max_speed|
+-----------------+------+------+---------+
|                0|falcon|  bird|    389.0|
|                1|parrot|  bird|     24.0|
|                2|  lion|mammal|     80.5|
|                3|monkey|mammal|     null|
+-----------------+------+------+---------+
```

(Although the 'name' column was dropped in the example above, it still appears in the internal Spark DataFrame.)

So I think we should drop the columns from the internal Spark DataFrame as well,

like this:

```
>>> df._sdf.show()
+-----------------+------+---------+
|__index_level_0__| class|max_speed|
+-----------------+------+---------+
|                0|  bird|    389.0|
|                1|  bird|     24.0|
|                2|mammal|     80.5|
|                3|mammal|     null|
+-----------------+------+---------+
```
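The mismatch described above can be reproduced without Spark. The following is a minimal sketch, assuming a hypothetical `ToyFrame` class (not a Koalas class) in which `visible_columns` stands in for the column metadata and `backing_columns`/`rows` stand in for the internal Spark DataFrame; the index column is left out for brevity.

```python
# Toy model only: ToyFrame and its attributes are hypothetical, not Koalas code.
class ToyFrame:
    def __init__(self, backing_columns, rows, visible_columns=None):
        # backing_columns/rows play the role of the internal Spark DataFrame.
        self.backing_columns = list(backing_columns)
        self.rows = [tuple(row) for row in rows]
        # visible_columns plays the role of the column metadata.
        if visible_columns is None:
            visible_columns = backing_columns
        self.visible_columns = list(visible_columns)

    def drop_metadata_only(self, name):
        """Behaviour before the fix: hide the column, keep the backing data."""
        visible = [c for c in self.visible_columns if c != name]
        return ToyFrame(self.backing_columns, self.rows, visible)

    def drop_and_project(self, name):
        """Behaviour after the fix: also project the backing data."""
        visible = [c for c in self.visible_columns if c != name]
        keep = [self.backing_columns.index(c) for c in visible]
        rows = [tuple(row[i] for i in keep) for row in self.rows]
        return ToyFrame(visible, rows, visible)


df = ToyFrame(["name", "class", "max_speed"],
              [("falcon", "bird", 389.0), ("parrot", "bird", 24.0)])

stale = df.drop_metadata_only("name")
print(stale.visible_columns, stale.backing_columns)

fixed = df.drop_and_project("name")
print(fixed.visible_columns, fixed.backing_columns)
```

In the first case the metadata and the backing data disagree about which columns exist, which is the situation the commit message reports for `df._sdf`.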
itholic authored and HyukjinKwon committed Sep 20, 2019
1 parent eb147d1 commit 275463a
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions databricks/koalas/frame.py
@@ -4449,6 +4449,21 @@ def drop(self, labels=None, axis=1,
         0  1  7
         1  2  8

+        >>> df = ks.DataFrame({'x': [1, 2], 'y': [3, 4], 'z': [5, 6], 'w': [7, 8]},
+        ...                   columns=['x', 'y', 'z', 'w'])
+        >>> columns = [('a', 'x'), ('a', 'y'), ('b', 'z'), ('b', 'w')]
+        >>> df.columns = pd.MultiIndex.from_tuples(columns)
+        >>> df  # doctest: +NORMALIZE_WHITESPACE
+           a     b
+           x  y  z  w
+        0  1  3  5  7
+        1  2  4  6  8
+        >>> df.drop('a')  # doctest: +NORMALIZE_WHITESPACE
+           b
+           z  w
+        0  5  7
+        1  6  8
+
         Notes
         -----
         Currently only axis = 1 is supported in this function,
@@ -4472,11 +4487,15 @@ def drop(self, labels=None, axis=1,
                                     if idx[:len(col)] == col)
             if len(drop_column_index) == 0:
                 raise KeyError(columns)
-            cols, idx = zip(*((column, idx)
+            cols, idxes = zip(*((column, idx)
                               for column, idx
                               in zip(self._internal.data_columns, self._internal.column_index)
                               if idx not in drop_column_index))
-            internal = self._internal.copy(data_columns=list(cols), column_index=list(idx))
+            internal = self._internal.copy(
+                sdf=self._sdf.select(
+                    self._internal.index_scols + [self._internal.scol_for(idx) for idx in idxes]),
+                data_columns=list(cols),
+                column_index=list(idxes))
             return DataFrame(internal)
         else:
             raise ValueError("Need to specify at least one of 'labels' or 'columns'")
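The column-selection step in this hunk can be tried in isolation. The following is a rough sketch in plain Python, assuming a hypothetical helper `select_remaining` and made-up backing column names `c0`..`c3`; it mirrors the prefix match (`idx[:len(col)] == col`) and the `zip(*...)` unpacking from the diff, but it is not the Koalas implementation. The guard for the case where every column is dropped is an addition of this sketch, since unpacking an empty `zip()` into two names would raise `ValueError`.

```python
def select_remaining(data_columns, column_index, columns):
    """Return (cols, idxes) that remain after dropping the given labels.

    Hypothetical helper for illustration only.
    data_columns: backing column names, one per label in column_index.
    column_index: tuple labels, e.g. ('a', 'x') for a two-level column index.
    columns:      labels to drop; a plain string is treated as a 1-tuple.
    """
    # Normalise so that 'a' becomes ('a',) and matches ('a', 'x') and ('a', 'y').
    columns = [col if isinstance(col, tuple) else (col,) for col in columns]

    # A label is dropped when one of the requested labels is a prefix of it.
    drop_column_index = {idx for idx in column_index for col in columns
                         if idx[:len(col)] == col}
    if len(drop_column_index) == 0:
        raise KeyError(columns)

    kept = [(column, idx)
            for column, idx in zip(data_columns, column_index)
            if idx not in drop_column_index]
    if not kept:
        # Guard added in this sketch: nothing left to unpack.
        return [], []
    cols, idxes = zip(*kept)
    return list(cols), list(idxes)


column_index = [('a', 'x'), ('a', 'y'), ('b', 'z'), ('b', 'w')]
data_columns = ['c0', 'c1', 'c2', 'c3']  # hypothetical backing column names

cols, idxes = select_remaining(data_columns, column_index, ['a'])
print(cols)
print(idxes)
```

The returned `idxes` are what the patched code passes to `scol_for` to build the `select` on the Spark DataFrame, so the metadata and the backing data are narrowed from the same list.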