Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix filter to support multi-index columns. #859

Merged
merged 1 commit into from
Oct 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions databricks/koalas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6838,9 +6838,9 @@ def filter(self, items=None, like=None, regex=None, axis=None):
sdf = sdf.filter(index_scols[0].contains(like))
return DataFrame(self._internal.copy(sdf=sdf))
elif axis in ('columns', 1, None):
data_columns = self._internal.data_columns
output_columns = [c for c in data_columns if like in c]
return self[output_columns]
column_index = self._internal.column_index
output_idx = [idx for idx in column_index if any(like in i for i in idx)]
return self[output_idx]
elif regex is not None:
if axis in ('index', 0):
# TODO: support multi-index here
Expand All @@ -6849,10 +6849,11 @@ def filter(self, items=None, like=None, regex=None, axis=None):
sdf = sdf.filter(index_scols[0].rlike(regex))
return DataFrame(self._internal.copy(sdf=sdf))
elif axis in ('columns', 1, None):
data_columns = self._internal.data_columns
column_index = self._internal.column_index
matcher = re.compile(regex)
output_columns = [c for c in data_columns if matcher.search(c) is not None]
return self[output_columns]
output_idx = [idx for idx in column_index
if any(matcher.search(i) is not None for i in idx)]
return self[output_idx]
else:
raise TypeError("Must pass either `items`, `like`, or `regex`")

Expand Down
4 changes: 2 additions & 2 deletions databricks/koalas/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,14 +720,14 @@ def from_pandas(pdf: pd.DataFrame) -> '_InternalFrame':
for i, name in enumerate(index.names)]
else:
name = index.name
index_map = [(name if name is not None else '__index_level_0__',
index_map = [(str(name) if name is not None else '__index_level_0__',
name if name is None or isinstance(name, tuple) else (name,))]

index_columns = [index_column for index_column, _ in index_map]

reset_index = pdf.reset_index()
reset_index.columns = index_columns + data_columns
schema = StructType([StructField(name, infer_pd_series_spark_type(col),
schema = StructType([StructField(str(name), infer_pd_series_spark_type(col),
nullable=bool(col.isnull().any()))
for name, col in reset_index.iteritems()])
for name, col in reset_index.iteritems():
Expand Down
35 changes: 31 additions & 4 deletions databricks/koalas/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1834,10 +1834,10 @@ def test_filter(self):
self.assert_eq(kdf.filter(like='b', axis='index'), pdf.filter(like='b', axis='index'))
self.assert_eq(kdf.filter(like='c', axis='columns'), pdf.filter(like='c', axis='columns'))

self.assert_eq(
kdf.filter(regex='b.*', axis='index'), pdf.filter(regex='b.*', axis='index'))
self.assert_eq(
kdf.filter(regex='b.*', axis='columns'), pdf.filter(regex='b.*', axis='columns'))
self.assert_eq(kdf.filter(regex='b.*', axis='index'),
pdf.filter(regex='b.*', axis='index'))
self.assert_eq(kdf.filter(regex='b.*', axis='columns'),
pdf.filter(regex='b.*', axis='columns'))

pdf = pdf.set_index('ba', append=True)
kdf = ks.from_pandas(pdf)
Expand All @@ -1863,6 +1863,33 @@ def test_filter(self):
with self.assertRaisesRegex(TypeError, "mutually exclusive"):
kdf.filter(regex='b.*', like="aaa")

# multi-index columns
pdf = pd.DataFrame({
('x', 'aa'): ['aa', 'ab', 'bc', 'bd', 'ce'],
('x', 'ba'): [1, 2, 3, 4, 5],
('y', 'cb'): [1., 2., 3., 4., 5.],
('z', 'db'): [1., np.nan, 3., np.nan, 5.],
})
pdf = pdf.set_index(('x', 'aa'))
kdf = ks.from_pandas(pdf)

self.assert_eq(
kdf.filter(items=['ab', 'aa'], axis=0).sort_index(),
pdf.filter(items=['ab', 'aa'], axis=0).sort_index())
self.assert_eq(
kdf.filter(items=[('x', 'ba'), ('z', 'db')], axis=1).sort_index(),
pdf.filter(items=[('x', 'ba'), ('z', 'db')], axis=1).sort_index())

self.assert_eq(kdf.filter(like='b', axis='index'),
pdf.filter(like='b', axis='index'))
self.assert_eq(kdf.filter(like='c', axis='columns'),
pdf.filter(like='c', axis='columns'))

self.assert_eq(kdf.filter(regex='b.*', axis='index'),
pdf.filter(regex='b.*', axis='index'))
self.assert_eq(kdf.filter(regex='b.*', axis='columns'),
pdf.filter(regex='b.*', axis='columns'))

def test_pipe(self):
kdf = ks.DataFrame({'category': ['A', 'A', 'B'],
'col1': [1, 2, 3],
Expand Down