Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix DataFrame.join for MultiIndex #1771

Merged
merged 5 commits into from
Sep 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions databricks/koalas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6930,16 +6930,24 @@ def join(
raise ValueError(
"columns overlap but no suffix specified: " "{rename}".format(rename=common)
)

need_set_index = False
if on:
if not is_list_like(on):
itholic marked this conversation as resolved.
Show resolved Hide resolved
on = [on] # type: ignore
if len(on) != right.index.nlevels:
raise ValueError(
'len(left_on) must equal the number of levels in the index of "right"'
)

need_set_index = len(set(on) & set(self.index.names)) == 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@itholic Would you please help me understand this line?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xinrong-databricks

Sure!

This line checks if the given join keys are already included in the Index or not.

If they are not (that is, the expression evaluates to True, because it checks whether the intersection of the two sets is empty), we need to set the given join keys as the Index using set_index below.

If you have any more questions, please feel free to ask! :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! Your explanation is so clear :)!

May I ask the reason for setting the given join keys as an Index?

Copy link
Contributor Author

@itholic itholic Jan 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xinrong-databricks

Sure!

This is because we use merge afterwards.

If the given join keys are not in the Index, the result of merge will not be correct.

For example, let's say we have two DataFrames as below.

>>> kdf1
  key   A
0  K0  A0
1  K1  A1
2  K2  A2
3  K3  A3

>>> kdf2
      B
key
K0   B0
K1   B1
K2   B2

We expect the result of join with on='key' to be as below.

>>> kdf1.join(kdf2, on=['key'])
  key   A     B
0  K3  A3  None
1  K0  A0    B0
2  K1  A1    B1
3  K2  A2    B2

We can make the same result with merge as below.

>>> kdf1.set_index('key').merge(kdf2, left_index=True, right_index=True, how='left').reset_index()
  key   A     B
0  K3  A3  None
1  K0  A0    B0
2  K1  A1    B1
3  K2  A2    B2

At this point, if we didn't call kdf1.set_index('key'), the result would be different, as shown below.

>>> kdf1.merge(kdf2, left_index=True, right_index=True, how='left').reset_index()
   index key   A     B
0      0  K0  A0  None
1      1  K1  A1  None
2      3  K3  A3  None
3      2  K2  A2  None

So, that's why we need set_index here!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @itholic ! That's so clear :).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xinrong-databricks np! Glad to know I helped :)

if need_set_index:
self = self.set_index(on)
join_kdf = self.merge(
right, left_index=True, right_index=True, how=how, suffixes=(lsuffix, rsuffix)
).reset_index()
else:
join_kdf = self.merge(
right, left_index=True, right_index=True, how=how, suffixes=(lsuffix, rsuffix)
)
return join_kdf

join_kdf = self.merge(
right, left_index=True, right_index=True, how=how, suffixes=(lsuffix, rsuffix)
)
return join_kdf.reset_index() if need_set_index else join_kdf

def append(
self,
Expand Down
59 changes: 52 additions & 7 deletions databricks/koalas/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1576,13 +1576,9 @@ def test_join(self):
pdf2 = pd.DataFrame(
{"key": ["K0", "K1", "K2"], "B": ["B0", "B1", "B2"]}, columns=["key", "B"]
)
kdf1 = ks.DataFrame(
{"key": ["K0", "K1", "K2", "K3"], "A": ["A0", "A1", "A2", "A3"]}, columns=["key", "A"]
)
kdf2 = ks.DataFrame(
{"key": ["K0", "K1", "K2"], "B": ["B0", "B1", "B2"]}, columns=["key", "B"]
)
ks1 = ks.Series(["A1", "A5"], index=[1, 2], name="A")
kdf1 = ks.from_pandas(pdf1)
kdf2 = ks.from_pandas(pdf2)

join_pdf = pdf1.join(pdf2, lsuffix="_left", rsuffix="_right")
join_pdf.sort_values(by=list(join_pdf.columns), inplace=True)

Expand All @@ -1593,6 +1589,7 @@ def test_join(self):

# join with duplicated columns in Series
with self.assertRaisesRegex(ValueError, "columns overlap but no suffix specified"):
ks1 = ks.Series(["A1", "A5"], index=[1, 2], name="A")
kdf1.join(ks1, how="outer")
# join with duplicated columns in DataFrame
with self.assertRaisesRegex(ValueError, "columns overlap but no suffix specified"):
Expand All @@ -1606,6 +1603,17 @@ def test_join(self):
join_kdf.sort_values(by=list(join_kdf.columns), inplace=True)
self.assert_eq(join_pdf.reset_index(drop=True), join_kdf.reset_index(drop=True))

join_pdf = pdf1.set_index("key").join(
pdf2.set_index("key"), on="key", lsuffix="_left", rsuffix="_right"
)
join_pdf.sort_values(by=list(join_pdf.columns), inplace=True)

join_kdf = kdf1.set_index("key").join(
kdf2.set_index("key"), on="key", lsuffix="_left", rsuffix="_right"
)
join_kdf.sort_values(by=list(join_kdf.columns), inplace=True)
self.assert_eq(join_pdf.reset_index(drop=True), join_kdf.reset_index(drop=True))

# multi-index columns
columns1 = pd.MultiIndex.from_tuples([("x", "key"), ("Y", "A")])
columns2 = pd.MultiIndex.from_tuples([("x", "key"), ("Y", "B")])
Expand Down Expand Up @@ -1635,6 +1643,43 @@ def test_join(self):

self.assert_eq(join_pdf.reset_index(drop=True), join_kdf.reset_index(drop=True))

join_pdf = pdf1.set_index(("x", "key")).join(
pdf2.set_index(("x", "key")), on=[("x", "key")], lsuffix="_left", rsuffix="_right"
)
join_pdf.sort_values(by=list(join_pdf.columns), inplace=True)

join_kdf = kdf1.set_index(("x", "key")).join(
kdf2.set_index(("x", "key")), on=[("x", "key")], lsuffix="_left", rsuffix="_right"
)
join_kdf.sort_values(by=list(join_kdf.columns), inplace=True)

self.assert_eq(join_pdf.reset_index(drop=True), join_kdf.reset_index(drop=True))

# multi-index
midx1 = pd.MultiIndex.from_tuples(
[("w", "a"), ("x", "b"), ("y", "c"), ("z", "d")], names=["index1", "index2"]
)
midx2 = pd.MultiIndex.from_tuples(
[("w", "a"), ("x", "b"), ("y", "c")], names=["index1", "index2"]
)
pdf1.index = midx1
pdf2.index = midx2
kdf1 = ks.from_pandas(pdf1)
kdf2 = ks.from_pandas(pdf2)

join_pdf = pdf1.join(pdf2, on=["index1", "index2"], rsuffix="_right")
join_pdf.sort_values(by=list(join_pdf.columns), inplace=True)

join_kdf = kdf1.join(kdf2, on=["index1", "index2"], rsuffix="_right")
join_kdf.sort_values(by=list(join_kdf.columns), inplace=True)

self.assert_eq(join_pdf, join_kdf)

with self.assertRaisesRegex(
ValueError, r'len\(left_on\) must equal the number of levels in the index of "right"'
):
kdf1.join(kdf2, on=["index1"], rsuffix="_right")

def test_replace(self):
pdf = pd.DataFrame(
{
Expand Down