Skip to content

Commit 7070bc6

Browse files
itholicHyukjinKwon
authored andcommitted
Implement DataFrame.where() & DataFrame.mask() (#1018)
Resolves #884 This PR implement `where` of `DataFrame` (https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.where.html#pandas.DataFrame.where) and `mask` of `DataFrame` (same as where except for the opposite cond) (https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.mask.html#pandas.DataFrame.mask) ```python >>> df1 = ks.DataFrame({'A': [0, 1, 2, 3, 4], 'B':[100, 200, 300, 400, 500]}) >>> df2 = ks.DataFrame({'A': [0, -1, -2, -3, -4], 'B':[-100, -200, -300, -400, -500]}) >>> df1 A B 0 0 100 1 1 200 2 2 300 3 3 400 4 4 500 >>> df2 A B 0 0 -100 1 -1 -200 2 -2 -300 3 -3 -400 4 -4 -500 >>> df1.where(df1 > 0).sort_index() A B 0 NaN 100.0 1 1.0 200.0 2 2.0 300.0 3 3.0 400.0 4 4.0 500.0 >>> df1.where(df1 > 1, 10).sort_index() A B 0 10 100 1 10 200 2 2 300 3 3 400 4 4 500 >>> df1.where(df1 > 1, df1 + 100).sort_index() A B 0 100 100 1 101 200 2 2 300 3 3 400 4 4 500 >>> df1.where(df1 > 1, df2).sort_index() A B 0 0 100 1 -1 200 2 2 300 3 3 400 4 4 500 ```
1 parent d8dfa71 commit 7070bc6

File tree

5 files changed

+329
-2
lines changed

5 files changed

+329
-2
lines changed

databricks/koalas/frame.py

+243
Original file line numberDiff line numberDiff line change
@@ -2042,6 +2042,249 @@ class locomotion
20422042

20432043
return result
20442044

2045+
def where(self, cond, other=np.nan):
2046+
"""
2047+
Replace values where the condition is False.
2048+
2049+
Parameters
2050+
----------
2051+
cond : boolean DataFrame
2052+
Where cond is True, keep the original value. Where False,
2053+
replace with corresponding value from other.
2054+
other : scalar, DataFrame
2055+
Entries where cond is False are replaced with corresponding value from other.
2056+
2057+
Returns
2058+
-------
2059+
DataFrame
2060+
2061+
Examples
2062+
--------
2063+
2064+
>>> from databricks.koalas.config import set_option, reset_option
2065+
>>> set_option("compute.ops_on_diff_frames", True)
2066+
>>> df1 = ks.DataFrame({'A': [0, 1, 2, 3, 4], 'B':[100, 200, 300, 400, 500]})
2067+
>>> df2 = ks.DataFrame({'A': [0, -1, -2, -3, -4], 'B':[-100, -200, -300, -400, -500]})
2068+
>>> df1
2069+
A B
2070+
0 0 100
2071+
1 1 200
2072+
2 2 300
2073+
3 3 400
2074+
4 4 500
2075+
>>> df2
2076+
A B
2077+
0 0 -100
2078+
1 -1 -200
2079+
2 -2 -300
2080+
3 -3 -400
2081+
4 -4 -500
2082+
2083+
>>> df1.where(df1 > 0).sort_index()
2084+
A B
2085+
0 NaN 100.0
2086+
1 1.0 200.0
2087+
2 2.0 300.0
2088+
3 3.0 400.0
2089+
4 4.0 500.0
2090+
2091+
>>> df1.where(df1 > 1, 10).sort_index()
2092+
A B
2093+
0 10 100
2094+
1 10 200
2095+
2 2 300
2096+
3 3 400
2097+
4 4 500
2098+
2099+
>>> df1.where(df1 > 1, df1 + 100).sort_index()
2100+
A B
2101+
0 100 100
2102+
1 101 200
2103+
2 2 300
2104+
3 3 400
2105+
4 4 500
2106+
2107+
>>> df1.where(df1 > 1, df2).sort_index()
2108+
A B
2109+
0 0 100
2110+
1 -1 200
2111+
2 2 300
2112+
3 3 400
2113+
4 4 500
2114+
2115+
When the column name of cond is different from self, it treats all values are False
2116+
2117+
>>> cond = ks.DataFrame({'C': [0, -1, -2, -3, -4], 'D':[4, 3, 2, 1, 0]}) % 3 == 0
2118+
>>> cond
2119+
C D
2120+
0 True False
2121+
1 False True
2122+
2 False False
2123+
3 True False
2124+
4 False True
2125+
2126+
>>> df1.where(cond).sort_index()
2127+
A B
2128+
0 NaN NaN
2129+
1 NaN NaN
2130+
2 NaN NaN
2131+
3 NaN NaN
2132+
4 NaN NaN
2133+
2134+
When the type of cond is Series, it just check boolean regardless of column name
2135+
2136+
>>> cond = ks.Series([1, 2]) > 1
2137+
>>> cond
2138+
0 False
2139+
1 True
2140+
Name: 0, dtype: bool
2141+
2142+
>>> df1.where(cond).sort_index()
2143+
A B
2144+
0 NaN NaN
2145+
1 1.0 200.0
2146+
2 NaN NaN
2147+
3 NaN NaN
2148+
4 NaN NaN
2149+
2150+
>>> reset_option("compute.ops_on_diff_frames")
2151+
"""
2152+
from databricks.koalas.series import Series
2153+
tmp_cond_col_name = '__tmp_cond_col_{}__'
2154+
tmp_other_col_name = '__tmp_other_col_{}__'
2155+
kdf = self.copy()
2156+
if isinstance(cond, DataFrame):
2157+
for column in self._internal.data_columns:
2158+
kdf[tmp_cond_col_name.format(column)] = cond.get(column, False)
2159+
elif isinstance(cond, Series):
2160+
for column in self._internal.data_columns:
2161+
kdf[tmp_cond_col_name.format(column)] = cond
2162+
else:
2163+
raise ValueError("type of cond must be a DataFrame or Series")
2164+
2165+
if isinstance(other, DataFrame):
2166+
for column in self._internal.data_columns:
2167+
kdf[tmp_other_col_name.format(column)] = other.get(column, np.nan)
2168+
else:
2169+
for column in self._internal.data_columns:
2170+
kdf[tmp_other_col_name.format(column)] = other
2171+
2172+
sdf = kdf._sdf
2173+
# above logic make spark dataframe looks like below:
2174+
# +-----------------+---+---+------------------+-------------------+------------------+--...
2175+
# |__index_level_0__| A| B|__tmp_cond_col_A__|__tmp_other_col_A__|__tmp_cond_col_B__|__...
2176+
# +-----------------+---+---+------------------+-------------------+------------------+--...
2177+
# | 0| 0|100| true| 0| false| ...
2178+
# | 1| 1|200| false| -1| false| ...
2179+
# | 3| 3|400| true| -3| false| ...
2180+
# | 2| 2|300| false| -2| true| ...
2181+
# | 4| 4|500| false| -4| false| ...
2182+
# +-----------------+---+---+------------------+-------------------+------------------+--...
2183+
2184+
output = []
2185+
for column in self._internal.data_columns:
2186+
data_col_name = self._internal.column_name_for(column)
2187+
output.append(
2188+
F.when(
2189+
sdf[tmp_cond_col_name.format(column)], sdf[data_col_name]
2190+
).otherwise(
2191+
sdf[tmp_other_col_name.format(column)]
2192+
).alias(data_col_name))
2193+
2194+
index_columns = self._internal.index_columns
2195+
sdf = sdf.select(*index_columns, *output)
2196+
2197+
return DataFrame(self._internal.copy(
2198+
sdf=sdf,
2199+
column_scols=[scol_for(sdf, column) for column in self._internal.data_columns]))
2200+
2201+
def mask(self, cond, other=np.nan):
2202+
"""
2203+
Replace values where the condition is True.
2204+
2205+
Parameters
2206+
----------
2207+
cond : boolean DataFrame
2208+
Where cond is False, keep the original value. Where True,
2209+
replace with corresponding value from other.
2210+
other : scalar, DataFrame
2211+
Entries where cond is True are replaced with corresponding value from other.
2212+
2213+
Returns
2214+
-------
2215+
DataFrame
2216+
2217+
Examples
2218+
--------
2219+
2220+
>>> from databricks.koalas.config import set_option, reset_option
2221+
>>> set_option("compute.ops_on_diff_frames", True)
2222+
>>> df1 = ks.DataFrame({'A': [0, 1, 2, 3, 4], 'B':[100, 200, 300, 400, 500]})
2223+
>>> df2 = ks.DataFrame({'A': [0, -1, -2, -3, -4], 'B':[-100, -200, -300, -400, -500]})
2224+
>>> df1
2225+
A B
2226+
0 0 100
2227+
1 1 200
2228+
2 2 300
2229+
3 3 400
2230+
4 4 500
2231+
>>> df2
2232+
A B
2233+
0 0 -100
2234+
1 -1 -200
2235+
2 -2 -300
2236+
3 -3 -400
2237+
4 -4 -500
2238+
2239+
>>> df1.mask(df1 > 0).sort_index()
2240+
A B
2241+
0 0.0 NaN
2242+
1 NaN NaN
2243+
2 NaN NaN
2244+
3 NaN NaN
2245+
4 NaN NaN
2246+
2247+
>>> df1.mask(df1 > 1, 10).sort_index()
2248+
A B
2249+
0 0 10
2250+
1 1 10
2251+
2 10 10
2252+
3 10 10
2253+
4 10 10
2254+
2255+
>>> df1.mask(df1 > 1, df1 + 100).sort_index()
2256+
A B
2257+
0 0 200
2258+
1 1 300
2259+
2 102 400
2260+
3 103 500
2261+
4 104 600
2262+
2263+
>>> df1.mask(df1 > 1, df2).sort_index()
2264+
A B
2265+
0 0 -100
2266+
1 1 -200
2267+
2 -2 -300
2268+
3 -3 -400
2269+
4 -4 -500
2270+
2271+
>>> reset_option("compute.ops_on_diff_frames")
2272+
"""
2273+
from databricks.koalas.series import Series
2274+
if not isinstance(cond, (DataFrame, Series)):
2275+
raise ValueError("type of cond must be a DataFrame or Series")
2276+
2277+
sdf = cond._internal.sdf
2278+
for col in cond._internal.data_columns:
2279+
sdf = sdf.withColumn(col, ~F.col(col))
2280+
2281+
internal = self._internal.copy(
2282+
sdf=sdf,
2283+
column_scols=[scol_for(sdf, column) for column in self._internal.data_columns])
2284+
cond_inversed = DataFrame(internal)
2285+
2286+
return self.where(cond_inversed, other)
2287+
20452288
@property
20462289
def index(self):
20472290
"""The index (row labels) Column of the DataFrame.

databricks/koalas/missing/frame.py

-2
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ class _MissingPandasLikeDataFrame(object):
6969
last_valid_index = unsupported_function('last_valid_index')
7070
lookup = unsupported_function('lookup')
7171
mad = unsupported_function('mad')
72-
mask = unsupported_function('mask')
7372
mode = unsupported_function('mode')
7473
pct_change = unsupported_function('pct_change')
7574
prod = unsupported_function('prod')
@@ -100,7 +99,6 @@ class _MissingPandasLikeDataFrame(object):
10099
tz_convert = unsupported_function('tz_convert')
101100
tz_localize = unsupported_function('tz_localize')
102101
unstack = unsupported_function('unstack')
103-
where = unsupported_function('where')
104102

105103
# Deprecated functions
106104
as_blocks = unsupported_function('as_blocks', deprecated=True)

databricks/koalas/tests/test_dataframe.py

+12
Original file line numberDiff line numberDiff line change
@@ -2227,3 +2227,15 @@ def test_quantile(self):
22272227

22282228
with self.assertRaisesRegex(ValueError, "quantile currently doesn't supports numeric_only"):
22292229
kdf.quantile(.5, numeric_only=False)
2230+
2231+
def test_where(self):
2232+
kdf = ks.from_pandas(self.pdf)
2233+
2234+
with self.assertRaisesRegex(ValueError, 'type of cond must be a DataFrame or Series'):
2235+
kdf.where(1)
2236+
2237+
def test_mask(self):
2238+
kdf = ks.from_pandas(self.pdf)
2239+
2240+
with self.assertRaisesRegex(ValueError, 'type of cond must be a DataFrame or Series'):
2241+
kdf.mask(1)

databricks/koalas/tests/test_ops_on_diff_frames.py

+72
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,40 @@ def test_loc_setitem(self):
422422

423423
self.assert_eq(kdf.sort_index(), pdf.sort_index())
424424

425+
def test_where(self):
426+
pdf1 = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B': [100, 200, 300, 400, 500]})
427+
pdf2 = pd.DataFrame({'A': [0, -1, -2, -3, -4], 'B': [-100, -200, -300, -400, -500]})
428+
kdf1 = ks.from_pandas(pdf1)
429+
kdf2 = ks.from_pandas(pdf2)
430+
431+
self.assert_eq(repr(pdf1.where(pdf2 > 100)),
432+
repr(kdf1.where(kdf2 > 100).sort_index()))
433+
434+
pdf1 = pd.DataFrame({'A': [-1, -2, -3, -4, -5], 'B': [-100, -200, -300, -400, -500]})
435+
pdf2 = pd.DataFrame({'A': [-10, -20, -30, -40, -50], 'B': [-5, -4, -3, -2, -1]})
436+
kdf1 = ks.from_pandas(pdf1)
437+
kdf2 = ks.from_pandas(pdf2)
438+
439+
self.assert_eq(repr(pdf1.where(pdf2 < -250)),
440+
repr(kdf1.where(kdf2 < -250).sort_index()))
441+
442+
def test_mask(self):
443+
pdf1 = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B': [100, 200, 300, 400, 500]})
444+
pdf2 = pd.DataFrame({'A': [0, -1, -2, -3, -4], 'B': [-100, -200, -300, -400, -500]})
445+
kdf1 = ks.from_pandas(pdf1)
446+
kdf2 = ks.from_pandas(pdf2)
447+
448+
self.assert_eq(repr(pdf1.mask(pdf2 < 100)),
449+
repr(kdf1.mask(kdf2 < 100).sort_index()))
450+
451+
pdf1 = pd.DataFrame({'A': [-1, -2, -3, -4, -5], 'B': [-100, -200, -300, -400, -500]})
452+
pdf2 = pd.DataFrame({'A': [-10, -20, -30, -40, -50], 'B': [-5, -4, -3, -2, -1]})
453+
kdf1 = ks.from_pandas(pdf1)
454+
kdf2 = ks.from_pandas(pdf2)
455+
456+
self.assert_eq(repr(pdf1.mask(pdf2 > -250)),
457+
repr(kdf1.mask(kdf2 > -250).sort_index()))
458+
425459
def test_multi_index_column_assignment_frame(self):
426460
pdf = pd.DataFrame({'a': [1, 2, 3, 2], 'b': [4.0, 2.0, 3.0, 1.0]})
427461
pdf.columns = pd.MultiIndex.from_tuples([('a', 'x'), ('a', 'y')])
@@ -493,3 +527,41 @@ def test_loc_setitem(self):
493527

494528
with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
495529
kdf.loc[['viper', 'sidewinder'], ['shield']] = another_kdf.max_speed
530+
531+
def test_where(self):
532+
pdf1 = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B': [100, 200, 300, 400, 500]})
533+
pdf2 = pd.DataFrame({'A': [0, -1, -2, -3, -4], 'B': [-100, -200, -300, -400, -500]})
534+
kdf1 = ks.from_pandas(pdf1)
535+
kdf2 = ks.from_pandas(pdf2)
536+
537+
with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
538+
self.assert_eq(repr(pdf1.where(pdf2 > 100)),
539+
repr(kdf1.where(kdf2 > 100).sort_index()))
540+
541+
pdf1 = pd.DataFrame({'A': [-1, -2, -3, -4, -5], 'B': [-100, -200, -300, -400, -500]})
542+
pdf2 = pd.DataFrame({'A': [-10, -20, -30, -40, -50], 'B': [-5, -4, -3, -2, -1]})
543+
kdf1 = ks.from_pandas(pdf1)
544+
kdf2 = ks.from_pandas(pdf2)
545+
546+
with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
547+
self.assert_eq(repr(pdf1.where(pdf2 < -250)),
548+
repr(kdf1.where(kdf2 < -250).sort_index()))
549+
550+
def test_mask(self):
551+
pdf1 = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B': [100, 200, 300, 400, 500]})
552+
pdf2 = pd.DataFrame({'A': [0, -1, -2, -3, -4], 'B': [-100, -200, -300, -400, -500]})
553+
kdf1 = ks.from_pandas(pdf1)
554+
kdf2 = ks.from_pandas(pdf2)
555+
556+
with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
557+
self.assert_eq(repr(pdf1.mask(pdf2 < 100)),
558+
repr(kdf1.mask(kdf2 < 100).sort_index()))
559+
560+
pdf1 = pd.DataFrame({'A': [-1, -2, -3, -4, -5], 'B': [-100, -200, -300, -400, -500]})
561+
pdf2 = pd.DataFrame({'A': [-10, -20, -30, -40, -50], 'B': [-5, -4, -3, -2, -1]})
562+
kdf1 = ks.from_pandas(pdf1)
563+
kdf2 = ks.from_pandas(pdf2)
564+
565+
with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
566+
self.assert_eq(repr(pdf1.mask(pdf2 > -250)),
567+
repr(kdf1.mask(kdf2 > -250).sort_index()))

docs/source/reference/frame.rst

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ Indexing, iteration
5858
DataFrame.keys
5959
DataFrame.xs
6060
DataFrame.get
61+
DataFrame.where
62+
DataFrame.mask
6163

6264
Binary operator functions
6365
-------------------------

0 commit comments

Comments
 (0)