Skip to content

Commit

Permalink
Fix utils.combine_frames for Series with different level of indexes…
Browse files Browse the repository at this point in the history
…. (#926)

This fixes the bug at databricks/koalas#923 (comment).
  • Loading branch information
rising-star92 committed Oct 15, 2019
1 parent 021eaf3 commit 0a5e04c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
8 changes: 8 additions & 0 deletions databricks/koalas/tests/test_ops_on_diff_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ def test_arithmetic(self):
(kdf1[('x', 'a')] - kdf2[('x', 'b')]).sort_index(),
(pdf1[('x', 'a')] - pdf2[('x', 'b')]).rename(('x', 'a')).sort_index(), almost=True)

self.assert_eq(
(kdf1[('x', 'a')] - kdf2['x']['b']).sort_index(),
(pdf1[('x', 'a')] - pdf2['x']['b']).rename(('x', 'a')).sort_index(), almost=True)

self.assert_eq(
(kdf1['x']['a'] - kdf2[('x', 'b')]).sort_index(),
(pdf1['x']['a'] - pdf2[('x', 'b')]).rename('a').sort_index(), almost=True)

# DataFrame
self.assert_eq(
(kdf1 + kdf2).sort_index(),
Expand Down
14 changes: 9 additions & 5 deletions databricks/koalas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,13 @@ def combine_frames(this, *args, how="full"):

index_columns = set(this._internal.index_columns)
new_data_columns = [c for c in joined_df.columns if c not in index_columns]
column_index = ([tuple(['this', *idx]) for idx in this._internal.column_index]
+ [tuple(['that', *idx]) for idx in that._internal.column_index])
column_index_names = (([None] + this._internal.column_index_names)
level = max(this._internal.column_index_level, that._internal.column_index_level)
column_index = ([tuple(['this'] + ([''] * (level - len(idx))) + list(idx))
for idx in this._internal.column_index]
+ [tuple(['that'] + ([''] * (level - len(idx))) + list(idx))
for idx in that._internal.column_index])
column_index_names = ((([None] * (1 + level - len(this._internal.column_index_level)))
+ this._internal.column_index_names)
if this._internal.column_index_names is not None else None)
return DataFrame(
this._internal.copy(sdf=joined_df, data_columns=new_data_columns,
Expand Down Expand Up @@ -249,10 +253,10 @@ def align_diff_series(func, this_series, *args, how="full"):
cols = [arg for arg in args if isinstance(arg, IndexOpsMixin)]
combined = combine_frames(this_series.to_frame(), *cols, how=how)

that_columns = [combined[tuple(['that', *arg._internal.column_index[0]])]._scol
that_columns = [combined['that'][arg._internal.column_index[0]]._scol
if isinstance(arg, IndexOpsMixin) else arg for arg in args]

scol = func(combined[tuple(['this', *this_series._internal.column_index[0]])]._scol,
scol = func(combined['this'][this_series._internal.column_index[0]]._scol,
*that_columns)

return Series(combined._internal.copy(scol=scol,
Expand Down

0 comments on commit 0a5e04c

Please sign in to comment.