
Commit

Add sort parameter to concat. (#1636)
Adds a `sort` parameter to `ks.concat()`.
ueshin authored Jul 9, 2020
1 parent fd047b5 commit 90642b0
Showing 2 changed files with 107 additions and 46 deletions.
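
For quick context before the diff: the new keyword mirrors the `sort` option of `pd.concat` and controls whether the non-concatenation axis is sorted when the inputs do not align. A minimal usage sketch, assuming a working Koalas/Spark environment; the frames below simply mirror the docstring examples further down:

import databricks.koalas as ks

df1 = ks.DataFrame({"letter": ["a", "b"], "number": [1, 2]})
df3 = ks.DataFrame({"letter": ["c", "d"], "number": [3, 4], "animal": ["cat", "dog"]})

# Default keeps the columns in first-seen order: letter, number, animal.
print(ks.concat([df1, df3]))

# sort=True sorts the non-concatenation axis (here the columns): animal, letter, number.
print(ks.concat([df1, df3], sort=True))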
78 changes: 52 additions & 26 deletions databricks/koalas/namespace.py
@@ -20,8 +20,8 @@
from typing import Optional, Union, List, Tuple
from collections import OrderedDict
from collections.abc import Iterable
from distutils.version import LooseVersion
from functools import reduce
import itertools

import numpy as np
import pandas as pd
@@ -1611,7 +1611,7 @@ def column_name(value):


# TODO: there are many parameters to implement and support. See pandas's pd.concat.
def concat(objs, axis=0, join="outer", ignore_index=False):
def concat(objs, axis=0, join="outer", ignore_index=False, sort=False):
"""
Concatenate pandas objects along a particular axis with optional set logic
along the other axes.
@@ -1631,6 +1631,8 @@ def concat(objs, axis=0, join="outer", ignore_index=False):
concatenating objects where the concatenation axis does not have
meaningful indexing information. Note the index values on the other
axes are still respected in the join.
sort : bool, default False
Sort non-concatenation axis if it is not already aligned.
Returns
-------
@@ -1693,14 +1695,12 @@
Combine ``DataFrame`` and ``Series`` objects with different columns.
>>> ks.concat([df2, s1, s2])
0 letter number
0 None c 3.0
1 None d 4.0
0 a None NaN
1 b None NaN
0 c None NaN
1 d None NaN
>>> ks.concat([df2, s1])
letter number 0
0 c 3.0 None
1 d 4.0 None
0 None NaN a
1 None NaN b
Combine ``DataFrame`` objects with overlapping columns
and return everything. Columns outside the intersection will
@@ -1714,6 +1714,15 @@
1 d 4 dog
>>> ks.concat([df1, df3])
letter number animal
0 a 1 None
1 b 2 None
0 c 3 cat
1 d 4 dog
Sort the columns.
>>> ks.concat([df1, df3], sort=True)
animal letter number
0 None a 1
1 None b 2
@@ -1825,6 +1834,9 @@ def resolve_func(kdf, this_column_labels, that_column_labels):
if ignore_index:
concat_kdf.columns = list(map(str, _range(len(concat_kdf.columns))))

if sort:
concat_kdf = concat_kdf.sort_index()

return concat_kdf

# Series, Series ...
@@ -1834,9 +1846,11 @@ def resolve_func(kdf, this_column_labels, that_column_labels):
# DataFrame, Series ... & Series, Series ...
# In this case, we should return DataFrame.
new_objs = []
num_series = 0
for obj in objs:
if isinstance(obj, Series):
obj = obj.rename(SPARK_DEFAULT_SERIES_NAME).to_dataframe()
num_series += 1
obj = obj.to_frame(SPARK_DEFAULT_SERIES_NAME)
new_objs.append(obj)
objs = new_objs

@@ -1859,35 +1873,47 @@
)
)

column_labelses_of_kdfs = [kdf._internal.column_labels for kdf in objs]
column_labels_of_kdfs = [kdf._internal.column_labels for kdf in objs]
if ignore_index:
index_names_of_kdfs = [[] for _ in objs]
else:
index_names_of_kdfs = [kdf._internal.index_names for kdf in objs]
if all(name == index_names_of_kdfs[0] for name in index_names_of_kdfs) and all(
idx == column_labelses_of_kdfs[0] for idx in column_labelses_of_kdfs
idx == column_labels_of_kdfs[0] for idx in column_labels_of_kdfs
):
# If all columns are in the same order and values, use it.
kdfs = objs
else:
if join == "inner":
interested_columns = set.intersection(*map(set, column_labelses_of_kdfs))
interested_columns = set.intersection(*map(set, column_labels_of_kdfs))
# Keep the column order of the first DataFrame.
merged_columns = sorted(
list(
map(
lambda c: column_labelses_of_kdfs[0][column_labelses_of_kdfs[0].index(c)],
interested_columns,
)
)
)
merged_columns = [
label for label in column_labels_of_kdfs[0] if label in interested_columns
]

# For multi-index columns, pandas is flaky when `join="inner"` and `sort=False`,
# so always sort to follow the behavior of the `join="outer"` case.
if (len(merged_columns) > 0 and len(merged_columns[0]) > 1) or sort:
merged_columns = sorted(merged_columns)

kdfs = [kdf[merged_columns] for kdf in objs]
elif join == "outer":
# If there are columns unmatched, just sort the column names.
merged_columns = sorted(
list(set(itertools.chain.from_iterable(column_labelses_of_kdfs)))
)
merged_columns = []
for labels in column_labels_of_kdfs:
merged_columns.extend(label for label in labels if label not in merged_columns)

assert len(merged_columns) > 0

if LooseVersion(pd.__version__) < LooseVersion("0.24"):
# Always sort for multi-index columns; otherwise, never sort if any Series is involved.
sort = len(merged_columns[0]) > 1 or (num_series == 0 and sort)
else:
# Always sort for multi-index columns or when more than one Series is involved;
# with exactly one Series, never sort.
sort = len(merged_columns[0]) > 1 or num_series > 1 or (num_series != 1 and sort)

if sort:
merged_columns = sorted(merged_columns)

kdfs = []
for kdf in objs:
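A note on the new `join="outer"` path above: column labels are now merged in first-seen order across the inputs, and sorting only happens when pandas itself would sort (multi-index columns, more than one Series input depending on the pandas version, or an explicit `sort=True`). A self-contained sketch of that rule, using plain label tuples instead of Koalas internals (the helper name `merge_outer_labels` is ours, not part of the library):

def merge_outer_labels(column_labels_of_kdfs, num_series, sort):
    # Merge labels in first-seen order across all inputs (order-preserving dedup).
    merged_columns = []
    for labels in column_labels_of_kdfs:
        merged_columns.extend(label for label in labels if label not in merged_columns)

    assert len(merged_columns) > 0

    # Decision used for pandas >= 0.24 in the patch (the pre-0.24 branch differs slightly):
    # always sort for multi-index columns or when more than one Series is involved;
    # with exactly one Series, never sort; otherwise honor the caller's `sort`.
    sort = len(merged_columns[0]) > 1 or num_series > 1 or (num_series != 1 and sort)

    return sorted(merged_columns) if sort else merged_columns


print(merge_outer_labels([[("A",), ("C",)], [("B",), ("A",)]], num_series=0, sort=False))
# -> [('A',), ('C',), ('B',)]
print(merge_outer_labels([[("A",), ("C",)], [("B",), ("A",)]], num_series=0, sort=True))
# -> [('A',), ('B',), ('C',)]
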
75 changes: 55 additions & 20 deletions databricks/koalas/tests/test_namespace.py
@@ -66,20 +66,35 @@ def test_to_datetime(self):
ks.to_datetime([1, 2, 3], unit="D", origin=pd.Timestamp("1960-01-01")),
)

def test_concat(self):
pdf = pd.DataFrame({"A": [0, 2, 4], "B": [1, 3, 5]})
def test_concat_index_axis(self):
pdf = pd.DataFrame({"A": [0, 2, 4], "B": [1, 3, 5], "C": [6, 7, 8]})
kdf = ks.from_pandas(pdf)

self.assert_eq(ks.concat([kdf, kdf.reset_index()]), pd.concat([pdf, pdf.reset_index()]))
ignore_indexes = [True, False]
joins = ["inner", "outer"]
sorts = [True, False]

self.assert_eq(
ks.concat([kdf, kdf[["A"]]], ignore_index=True),
pd.concat([pdf, pdf[["A"]]], ignore_index=True),
)
objs = [
([kdf, kdf], [pdf, pdf]),
([kdf, kdf.reset_index()], [pdf, pdf.reset_index()]),
([kdf.reset_index(), kdf], [pdf.reset_index(), pdf]),
([kdf, kdf[["C", "A"]]], [pdf, pdf[["C", "A"]]]),
([kdf[["C", "A"]], kdf], [pdf[["C", "A"]], pdf]),
([kdf, kdf["C"]], [pdf, pdf["C"]]),
([kdf["C"], kdf], [pdf["C"], pdf]),
([kdf["C"], kdf, kdf["A"]], [pdf["C"], pdf, pdf["A"]]),
([kdf, kdf["C"], kdf["A"]], [pdf, pdf["C"], pdf["A"]]),
]

self.assert_eq(
ks.concat([kdf, kdf[["A"]]], join="inner"), pd.concat([pdf, pdf[["A"]]], join="inner")
)
for ignore_index, join, sort in itertools.product(ignore_indexes, joins, sorts):
for obj in objs:
kdfs, pdfs = obj
with self.subTest(ignore_index=ignore_index, join=join, sort=sort, objs=pdfs):
self.assert_eq(
ks.concat(kdfs, ignore_index=ignore_index, join=join, sort=sort),
pd.concat(pdfs, ignore_index=ignore_index, join=join, sort=sort),
almost=(join == "outer"),
)

self.assertRaisesRegex(TypeError, "first argument must be", lambda: ks.concat(kdf))
self.assertRaisesRegex(TypeError, "cannot concatenate object", lambda: ks.concat([kdf, 1]))
@@ -96,27 +111,47 @@ def test_concat(self):
pdf3 = pdf.copy()
kdf3 = kdf.copy()

columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B")])
columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C")])
pdf3.columns = columns
kdf3.columns = columns

self.assert_eq(ks.concat([kdf3, kdf3.reset_index()]), pd.concat([pdf3, pdf3.reset_index()]))
objs = [
([kdf3, kdf3], [pdf3, pdf3]),
([kdf3, kdf3.reset_index()], [pdf3, pdf3.reset_index()]),
([kdf3.reset_index(), kdf3], [pdf3.reset_index(), pdf3]),
([kdf3, kdf3[[("Y", "C"), ("X", "A")]]], [pdf3, pdf3[[("Y", "C"), ("X", "A")]]]),
([kdf3[[("Y", "C"), ("X", "A")]], kdf3], [pdf3[[("Y", "C"), ("X", "A")]], pdf3]),
]

self.assert_eq(
ks.concat([kdf3, kdf3[[("X", "A")]]], ignore_index=True),
pd.concat([pdf3, pdf3[[("X", "A")]]], ignore_index=True),
)
for ignore_index, sort in itertools.product(ignore_indexes, sorts):
for obj in objs:
kdfs, pdfs = obj
with self.subTest(ignore_index=ignore_index, join="outer", sort=sort, objs=pdfs):
self.assert_eq(
ks.concat(kdfs, ignore_index=ignore_index, join="outer", sort=sort),
pd.concat(pdfs, ignore_index=ignore_index, join="outer", sort=sort),
)

self.assert_eq(
ks.concat([kdf3, kdf3[[("X", "A")]]], join="inner"),
pd.concat([pdf3, pdf3[[("X", "A")]]], join="inner"),
)
# Skip tests for `join="inner"` and `sort=False` since pandas is flaky.
for ignore_index in ignore_indexes:
for obj in objs:
kdfs, pdfs = obj
with self.subTest(ignore_index=ignore_index, join="inner", sort=True, objs=pdfs):
self.assert_eq(
ks.concat(kdfs, ignore_index=ignore_index, join="inner", sort=True),
pd.concat(pdfs, ignore_index=ignore_index, join="inner", sort=True),
)

self.assertRaisesRegex(
ValueError,
"MultiIndex columns should have the same levels",
lambda: ks.concat([kdf, kdf3]),
)
self.assertRaisesRegex(
ValueError,
"MultiIndex columns should have the same levels",
lambda: ks.concat([kdf3[("Y", "C")], kdf3]),
)

pdf4 = pd.DataFrame({"A": [0, 2, 4], "B": [1, 3, 5], "C": [10, 20, 30]})
kdf4 = ks.from_pandas(pdf4)
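
The reworked tests drive the full (ignore_index, join, sort) grid with `itertools.product` and `subTest`, comparing `ks.concat` against `pd.concat` for each combination. A pandas-only sketch of that looping pattern (plain asserts stand in for `assert_eq`, and the row-count check is just an illustrative invariant, not what the suite asserts):

import itertools

import pandas as pd

# Same base frame as the new test_concat_index_axis test.
pdf = pd.DataFrame({"A": [0, 2, 4], "B": [1, 3, 5], "C": [6, 7, 8]})

for ignore_index, join, sort in itertools.product([True, False], ["inner", "outer"], [True, False]):
    result = pd.concat([pdf, pdf[["C", "A"]]], ignore_index=ignore_index, join=join, sort=sort)
    # Whatever the flags, row-wise concat keeps every input row.
    assert len(result) == 2 * len(pdf)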
