Skip to content

Commit ef04871

Browse files
committed
Fix filter for multi-index columns support.
1 parent cf22a5d commit ef04871

File tree

3 files changed

+40
-12
lines changed

3 files changed

+40
-12
lines changed

databricks/koalas/frame.py

+7 -6
Original file line number | Diff line number | Diff line change
@@ -6838,9 +6838,9 @@ def filter(self, items=None, like=None, regex=None, axis=None):
68386838
sdf = sdf.filter(index_scols[0].contains(like))
68396839
return DataFrame(self._internal.copy(sdf=sdf))
68406840
elif axis in ('columns', 1, None):
6841-
data_columns = self._internal.data_columns
6842-
output_columns = [c for c in data_columns if like in c]
6843-
return self[output_columns]
6841+
column_index = self._internal.column_index
6842+
output_idx = [idx for idx in column_index if any(like in i for i in idx)]
6843+
return self[output_idx]
68446844
elif regex is not None:
68456845
if axis in ('index', 0):
68466846
# TODO: support multi-index here
@@ -6849,10 +6849,11 @@ def filter(self, items=None, like=None, regex=None, axis=None):
68496849
sdf = sdf.filter(index_scols[0].rlike(regex))
68506850
return DataFrame(self._internal.copy(sdf=sdf))
68516851
elif axis in ('columns', 1, None):
6852-
data_columns = self._internal.data_columns
6852+
column_index = self._internal.column_index
68536853
matcher = re.compile(regex)
6854-
output_columns = [c for c in data_columns if matcher.search(c) is not None]
6855-
return self[output_columns]
6854+
output_idx = [idx for idx in column_index
6855+
if any(matcher.search(i) is not None for i in idx)]
6856+
return self[output_idx]
68566857
else:
68576858
raise TypeError("Must pass either `items`, `like`, or `regex`")
68586859

databricks/koalas/internal.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -720,14 +720,14 @@ def from_pandas(pdf: pd.DataFrame) -> '_InternalFrame':
720720
for i, name in enumerate(index.names)]
721721
else:
722722
name = index.name
723-
index_map = [(name if name is not None else '__index_level_0__',
723+
index_map = [(str(name) if name is not None else '__index_level_0__',
724724
name if name is None or isinstance(name, tuple) else (name,))]
725725

726726
index_columns = [index_column for index_column, _ in index_map]
727727

728728
reset_index = pdf.reset_index()
729729
reset_index.columns = index_columns + data_columns
730-
schema = StructType([StructField(name, infer_pd_series_spark_type(col),
730+
schema = StructType([StructField(str(name), infer_pd_series_spark_type(col),
731731
nullable=bool(col.isnull().any()))
732732
for name, col in reset_index.iteritems()])
733733
for name, col in reset_index.iteritems():

databricks/koalas/tests/test_dataframe.py

+31 -4
Original file line number | Diff line number | Diff line change
@@ -1834,10 +1834,10 @@ def test_filter(self):
18341834
self.assert_eq(kdf.filter(like='b', axis='index'), pdf.filter(like='b', axis='index'))
18351835
self.assert_eq(kdf.filter(like='c', axis='columns'), pdf.filter(like='c', axis='columns'))
18361836

1837-
self.assert_eq(
1838-
kdf.filter(regex='b.*', axis='index'), pdf.filter(regex='b.*', axis='index'))
1839-
self.assert_eq(
1840-
kdf.filter(regex='b.*', axis='columns'), pdf.filter(regex='b.*', axis='columns'))
1837+
self.assert_eq(kdf.filter(regex='b.*', axis='index'),
1838+
pdf.filter(regex='b.*', axis='index'))
1839+
self.assert_eq(kdf.filter(regex='b.*', axis='columns'),
1840+
pdf.filter(regex='b.*', axis='columns'))
18411841

18421842
pdf = pdf.set_index('ba', append=True)
18431843
kdf = ks.from_pandas(pdf)
@@ -1863,6 +1863,33 @@ def test_filter(self):
18631863
with self.assertRaisesRegex(TypeError, "mutually exclusive"):
18641864
kdf.filter(regex='b.*', like="aaa")
18651865

1866+
# multi-index columns
1867+
pdf = pd.DataFrame({
1868+
('x', 'aa'): ['aa', 'ab', 'bc', 'bd', 'ce'],
1869+
('x', 'ba'): [1, 2, 3, 4, 5],
1870+
('y', 'cb'): [1., 2., 3., 4., 5.],
1871+
('z', 'db'): [1., np.nan, 3., np.nan, 5.],
1872+
})
1873+
pdf = pdf.set_index(('x', 'aa'))
1874+
kdf = ks.from_pandas(pdf)
1875+
1876+
self.assert_eq(
1877+
kdf.filter(items=['ab', 'aa'], axis=0).sort_index(),
1878+
pdf.filter(items=['ab', 'aa'], axis=0).sort_index())
1879+
self.assert_eq(
1880+
kdf.filter(items=[('x', 'ba'), ('z', 'db')], axis=1).sort_index(),
1881+
pdf.filter(items=[('x', 'ba'), ('z', 'db')], axis=1).sort_index())
1882+
1883+
self.assert_eq(kdf.filter(like='b', axis='index'),
1884+
pdf.filter(like='b', axis='index'))
1885+
self.assert_eq(kdf.filter(like='c', axis='columns'),
1886+
pdf.filter(like='c', axis='columns'))
1887+
1888+
self.assert_eq(kdf.filter(regex='b.*', axis='index'),
1889+
pdf.filter(regex='b.*', axis='index'))
1890+
self.assert_eq(kdf.filter(regex='b.*', axis='columns'),
1891+
pdf.filter(regex='b.*', axis='columns'))
1892+
18661893
def test_pipe(self):
18671894
kdf = ks.DataFrame({'category': ['A', 'A', 'B'],
18681895
'col1': [1, 2, 3],

0 commit comments

Comments
 (0)