Skip to content

Commit 5a950c0

Browse files
itholicHyukjinKwon
authored andcommitted
Fix value_counts() to work properly when dropna is True (#1116)
This PR Resolves comment #949 (comment) ```python >>> kdf a b 0 1 NaN 1 2 1.0 2 3 NaN >>> kdf.a 0 1 1 2 2 3 Name: a, dtype: int64 >>> kdf.a.value_counts() 2 1 3 1 1 1 Name: a, dtype: int64 ```
1 parent 063cec5 commit 5a950c0

File tree

2 files changed

+153
-5
lines changed

2 files changed

+153
-5
lines changed

databricks/koalas/base.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020

2121
from functools import wraps, partial
2222
from typing import Union, Callable, Any
23+
from distutils.version import LooseVersion
2324

25+
import pyspark
2426
import numpy as np
2527
import pandas as pd
2628
from pandas.api.types import is_list_like
@@ -33,7 +35,7 @@
3335
from databricks.koalas import numpy_compat
3436
from databricks.koalas.internal import _InternalFrame, SPARK_INDEX_NAME_FORMAT
3537
from databricks.koalas.typedef import pandas_wraps, spark_type_to_pandas_dtype
36-
from databricks.koalas.utils import align_diff_series, scol_for, validate_axis
38+
from databricks.koalas.utils import align_diff_series, scol_for, validate_axis, default_session
3739
from databricks.koalas.frame import DataFrame
3840

3941

@@ -947,15 +949,23 @@ def value_counts(self, normalize=False, sort=True, ascending=False, bins=None, d
947949
Name: koalas, dtype: int64
948950
"""
949951
from databricks.koalas.series import Series, _col
952+
from databricks.koalas.indexes import MultiIndex
953+
if LooseVersion(pyspark.__version__) < LooseVersion("2.4") and \
954+
default_session().conf.get("spark.sql.execution.arrow.enabled") == "true" and \
955+
isinstance(self, MultiIndex):
956+
raise RuntimeError("if you're using pyspark < 2.4, set conf "
957+
"'spark.sql.execution.arrow.enabled' to 'false' "
958+
"for using this function with MultiIndex")
950959
if bins is not None:
951960
raise NotImplementedError("value_counts currently does not support bins")
952961

953962
if dropna:
954-
sdf_dropna = self._internal._sdf.dropna()
963+
sdf_dropna = self._internal._sdf.select(self._scol).dropna()
955964
else:
956-
sdf_dropna = self._internal._sdf
965+
sdf_dropna = self._internal._sdf.select(self._scol)
957966
index_name = SPARK_INDEX_NAME_FORMAT(0)
958-
sdf = sdf_dropna.groupby(self._scol.alias(index_name)).count()
967+
column_name = self._internal.data_columns[0]
968+
sdf = sdf_dropna.groupby(scol_for(sdf_dropna, column_name).alias(index_name)).count()
959969
if sort:
960970
if ascending:
961971
sdf = sdf.orderBy(F.col('count'))

databricks/koalas/tests/test_series.py

+139-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import matplotlib
2525
matplotlib.use('agg')
2626
from matplotlib import pyplot as plt
27+
import pyspark
2728
import numpy as np
2829
import pandas as pd
2930

@@ -33,6 +34,7 @@
3334
from databricks.koalas.exceptions import PandasNotImplementedError
3435
from databricks.koalas.missing.series import _MissingPandasLikeSeries
3536
from databricks.koalas.config import set_option, reset_option
37+
from databricks.koalas.utils import default_session
3638

3739

3840
class SeriesTest(ReusedSQLTestCase, SQLTestUtils):
@@ -243,7 +245,8 @@ def test_nunique(self):
243245
self.assertEqual(ks.Series(range(100)).nunique(approx=True), 103)
244246
self.assertEqual(ks.Series(range(100)).nunique(approx=True, rsd=0.01), 100)
245247

246-
def test_value_counts(self):
248+
def _test_value_counts(self):
249+
# this is also containing test for Index & MultiIndex
247250
pser = pd.Series([1, 2, 1, 3, 3, np.nan, 1, 4], name="x")
248251
kser = ks.from_pandas(pser)
249252

@@ -261,6 +264,15 @@ def test_value_counts(self):
261264
self.assert_eq(kser.value_counts(ascending=True, dropna=False),
262265
pser.value_counts(ascending=True, dropna=False), almost=True)
263266

267+
self.assert_eq(kser.index.value_counts(normalize=True),
268+
pser.index.value_counts(normalize=True), almost=True)
269+
self.assert_eq(kser.index.value_counts(ascending=True),
270+
pser.index.value_counts(ascending=True), almost=True)
271+
self.assert_eq(kser.index.value_counts(normalize=True, dropna=False),
272+
pser.index.value_counts(normalize=True, dropna=False), almost=True)
273+
self.assert_eq(kser.index.value_counts(ascending=True, dropna=False),
274+
pser.index.value_counts(ascending=True, dropna=False), almost=True)
275+
264276
with self.assertRaisesRegex(NotImplementedError,
265277
"value_counts currently does not support bins"):
266278
kser.value_counts(bins=3)
@@ -269,6 +281,132 @@ def test_value_counts(self):
269281
kser.name = 'index'
270282
self.assert_eq(kser.value_counts(), pser.value_counts(), almost=True)
271283

284+
# Series from DataFrame
285+
pdf = pd.DataFrame({'a': [1, 2, 3], 'b': [None, 1, None]})
286+
kdf = ks.from_pandas(pdf)
287+
288+
self.assert_eq(kdf.a.value_counts(normalize=True),
289+
pdf.a.value_counts(normalize=True), almost=True)
290+
self.assert_eq(kdf.a.value_counts(ascending=True),
291+
pdf.a.value_counts(ascending=True), almost=True)
292+
self.assert_eq(kdf.a.value_counts(normalize=True, dropna=False),
293+
pdf.a.value_counts(normalize=True, dropna=False), almost=True)
294+
self.assert_eq(kdf.a.value_counts(ascending=True, dropna=False),
295+
pdf.a.value_counts(ascending=True, dropna=False), almost=True)
296+
297+
self.assert_eq(kser.index.value_counts(normalize=True),
298+
pser.index.value_counts(normalize=True), almost=True)
299+
self.assert_eq(kser.index.value_counts(ascending=True),
300+
pser.index.value_counts(ascending=True), almost=True)
301+
self.assert_eq(kser.index.value_counts(normalize=True, dropna=False),
302+
pser.index.value_counts(normalize=True, dropna=False), almost=True)
303+
self.assert_eq(kser.index.value_counts(ascending=True, dropna=False),
304+
pser.index.value_counts(ascending=True, dropna=False), almost=True)
305+
306+
# Series with NaN index
307+
pser = pd.Series([1, 2, 3], index=[2, None, 5])
308+
kser = ks.from_pandas(pser)
309+
310+
self.assert_eq(kser.value_counts(normalize=True),
311+
pser.value_counts(normalize=True), almost=True)
312+
self.assert_eq(kser.value_counts(ascending=True),
313+
pser.value_counts(ascending=True), almost=True)
314+
self.assert_eq(kser.value_counts(normalize=True, dropna=False),
315+
pser.value_counts(normalize=True, dropna=False), almost=True)
316+
self.assert_eq(kser.value_counts(ascending=True, dropna=False),
317+
pser.value_counts(ascending=True, dropna=False), almost=True)
318+
319+
self.assert_eq(kser.index.value_counts(normalize=True),
320+
pser.index.value_counts(normalize=True), almost=True)
321+
self.assert_eq(kser.index.value_counts(ascending=True),
322+
pser.index.value_counts(ascending=True), almost=True)
323+
self.assert_eq(kser.index.value_counts(normalize=True, dropna=False),
324+
pser.index.value_counts(normalize=True, dropna=False), almost=True)
325+
self.assert_eq(kser.index.value_counts(ascending=True, dropna=False),
326+
pser.index.value_counts(ascending=True, dropna=False), almost=True)
327+
328+
# Series with MultiIndex
329+
pser.index = pd.MultiIndex.from_tuples([('x', 'a'), ('x', 'b'), ('y', 'c')])
330+
kser = ks.from_pandas(pser)
331+
332+
self.assert_eq(kser.value_counts(normalize=True),
333+
pser.value_counts(normalize=True), almost=True)
334+
self.assert_eq(kser.value_counts(ascending=True),
335+
pser.value_counts(ascending=True), almost=True)
336+
self.assert_eq(kser.value_counts(normalize=True, dropna=False),
337+
pser.value_counts(normalize=True, dropna=False), almost=True)
338+
self.assert_eq(kser.value_counts(ascending=True, dropna=False),
339+
pser.value_counts(ascending=True, dropna=False), almost=True)
340+
341+
self.assert_eq(kser.index.value_counts(normalize=True),
342+
pser.index.value_counts(normalize=True), almost=True)
343+
self.assert_eq(kser.index.value_counts(ascending=True),
344+
pser.index.value_counts(ascending=True), almost=True)
345+
self.assert_eq(kser.index.value_counts(normalize=True, dropna=False),
346+
pser.index.value_counts(normalize=True, dropna=False), almost=True)
347+
self.assert_eq(kser.index.value_counts(ascending=True, dropna=False),
348+
pser.index.value_counts(ascending=True, dropna=False), almost=True)
349+
350+
# Series with MultiIndex some of index has NaN
351+
pser.index = pd.MultiIndex.from_tuples([('x', 'a'), ('x', None), ('y', 'c')])
352+
kser = ks.from_pandas(pser)
353+
354+
self.assert_eq(kser.value_counts(normalize=True),
355+
pser.value_counts(normalize=True), almost=True)
356+
self.assert_eq(kser.value_counts(ascending=True),
357+
pser.value_counts(ascending=True), almost=True)
358+
self.assert_eq(kser.value_counts(normalize=True, dropna=False),
359+
pser.value_counts(normalize=True, dropna=False), almost=True)
360+
self.assert_eq(kser.value_counts(ascending=True, dropna=False),
361+
pser.value_counts(ascending=True, dropna=False), almost=True)
362+
363+
self.assert_eq(kser.index.value_counts(normalize=True),
364+
pser.index.value_counts(normalize=True), almost=True)
365+
self.assert_eq(kser.index.value_counts(ascending=True),
366+
pser.index.value_counts(ascending=True), almost=True)
367+
self.assert_eq(kser.index.value_counts(normalize=True, dropna=False),
368+
pser.index.value_counts(normalize=True, dropna=False), almost=True)
369+
self.assert_eq(kser.index.value_counts(ascending=True, dropna=False),
370+
pser.index.value_counts(ascending=True, dropna=False), almost=True)
371+
372+
# Series with MultiIndex some of index is NaN.
373+
# This test only available for pandas >= 0.24.
374+
if LooseVersion(pd.__version__) >= LooseVersion("0.24"):
375+
pser.index = pd.MultiIndex.from_tuples([('x', 'a'), None, ('y', 'c')])
376+
kser = ks.from_pandas(pser)
377+
378+
self.assert_eq(kser.value_counts(normalize=True),
379+
pser.value_counts(normalize=True), almost=True)
380+
self.assert_eq(kser.value_counts(ascending=True),
381+
pser.value_counts(ascending=True), almost=True)
382+
self.assert_eq(kser.value_counts(normalize=True, dropna=False),
383+
pser.value_counts(normalize=True, dropna=False), almost=True)
384+
self.assert_eq(kser.value_counts(ascending=True, dropna=False),
385+
pser.value_counts(ascending=True, dropna=False), almost=True)
386+
387+
self.assert_eq(kser.index.value_counts(normalize=True),
388+
pser.index.value_counts(normalize=True), almost=True)
389+
self.assert_eq(kser.index.value_counts(ascending=True),
390+
pser.index.value_counts(ascending=True), almost=True)
391+
self.assert_eq(kser.index.value_counts(normalize=True, dropna=False),
392+
pser.index.value_counts(normalize=True, dropna=False), almost=True)
393+
self.assert_eq(kser.index.value_counts(ascending=True, dropna=False),
394+
pser.index.value_counts(ascending=True, dropna=False), almost=True)
395+
396+
def test_value_counts(self):
397+
if LooseVersion(pyspark.__version__) < LooseVersion("2.4") and \
398+
default_session().conf.get("spark.sql.execution.arrow.enabled") == "true":
399+
default_session().conf.set("spark.sql.execution.arrow.enabled", "false")
400+
try:
401+
self._test_value_counts()
402+
finally:
403+
default_session().conf.set("spark.sql.execution.arrow.enabled", "true")
404+
self.assertRaises(
405+
RuntimeError,
406+
lambda: ks.MultiIndex.from_tuples([('x', 'a'), ('x', 'b')]).value_counts())
407+
else:
408+
self._test_value_counts()
409+
272410
def test_nsmallest(self):
273411
sample_lst = [1, 2, 3, 4, np.nan, 6]
274412
pser = pd.Series(sample_lst, name='x')

0 commit comments

Comments
 (0)