Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix merge to support multi-index columns. #825

Merged
merged 3 commits into from
Sep 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 66 additions & 65 deletions databricks/koalas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5074,9 +5074,9 @@ def shape(self):

# TODO: support multi-index columns
def merge(self, right: 'DataFrame', how: str = 'inner',
on: Optional[Union[str, List[str]]] = None,
left_on: Optional[Union[str, List[str]]] = None,
right_on: Optional[Union[str, List[str]]] = None,
on: Union[str, List[str], Tuple[str, ...], List[Tuple[str, ...]]] = None,
left_on: Union[str, List[str], Tuple[str, ...], List[Tuple[str, ...]]] = None,
right_on: Union[str, List[str], Tuple[str, ...], List[Tuple[str, ...]]] = None,
left_index: bool = False, right_index: bool = False,
suffixes: Tuple[str, str] = ('_x', '_y')) -> 'DataFrame':
"""
Expand Down Expand Up @@ -5188,7 +5188,14 @@ def merge(self, right: 'DataFrame', how: str = 'inner',
As described in #263, joining string columns currently returns None for missing values
instead of NaN.
"""
_to_list = lambda o: o if o is None or is_list_like(o) else [o]
_to_list = lambda os: (os if os is None
else [os] if isinstance(os, tuple)
else [(os,)] if isinstance(os, str)
else [o if isinstance(o, tuple) else (o,) # type: ignore
for o in os])

if isinstance(right, ks.Series):
right = right.to_frame()

if on:
if left_on or right_on:
Expand All @@ -5212,16 +5219,13 @@ def merge(self, right: 'DataFrame', how: str = 'inner',
if right_keys and not left_keys:
raise ValueError('Must pass left_on or left_index=True')
if not left_keys and not right_keys:
if isinstance(right, ks.Series):
common = list(self.columns.intersection([right.name]))
else:
common = list(self.columns.intersection(right.columns))
common = list(self.columns.intersection(right.columns))
if len(common) == 0:
raise ValueError(
'No common columns to perform merge on. Merge options: '
'left_on=None, right_on=None, left_index=False, right_index=False')
left_keys = common
right_keys = common
left_keys = _to_list(common)
right_keys = _to_list(common)
if len(left_keys) != len(right_keys): # type: ignore
raise ValueError('len(left_keys) must equal len(right_keys)')

Expand All @@ -5235,11 +5239,14 @@ def merge(self, right: 'DataFrame', how: str = 'inner',
raise ValueError("The 'how' parameter has to be amongst the following values: ",
"['inner', 'left', 'right', 'outer']")

left_table = self._internal.spark_internal_df.alias('left_table')
right_table = right._internal.spark_internal_df.alias('right_table')
left_table = self._sdf.alias('left_table')
right_table = right._sdf.alias('right_table')

left_scol_for = lambda idx: scol_for(left_table, self._internal.column_name_for(idx))
right_scol_for = lambda idx: scol_for(right_table, right._internal.column_name_for(idx))

left_key_columns = [scol_for(left_table, col) for col in left_keys] # type: ignore
right_key_columns = [scol_for(right_table, col) for col in right_keys] # type: ignore
left_key_columns = [left_scol_for(idx) for idx in left_keys] # type: ignore
right_key_columns = [right_scol_for(idx) for idx in right_keys] # type: ignore

join_condition = reduce(lambda x, y: x & y,
[lkey == rkey for lkey, rkey
Expand All @@ -5252,20 +5259,18 @@ def merge(self, right: 'DataFrame', how: str = 'inner',
right_suffix = suffixes[1]

# Append suffixes to columns with the same name to avoid conflicts later
duplicate_columns = (set(self._internal.data_columns)
& set(right._internal.data_columns))

left_index_columns = set(self._internal.index_columns)
right_index_columns = set(right._internal.index_columns)
duplicate_columns = (set(self._internal.column_index)
& set(right._internal.column_index))

exprs = []
for col in left_table.columns:
if col in left_index_columns:
continue
scol = scol_for(left_table, col)
if col in duplicate_columns:
if col in left_keys and col in right_keys:
right_scol = scol_for(right_table, col)
data_columns = []
column_index = []
for idx in self._internal.column_index:
col = self._internal.column_name_for(idx)
scol = left_scol_for(idx)
if idx in duplicate_columns:
if idx in left_keys and idx in right_keys: # type: ignore
right_scol = right_scol_for(idx)
if how == 'right':
scol = right_scol
elif how == 'full':
Expand All @@ -5275,64 +5280,60 @@ def merge(self, right: 'DataFrame', how: str = 'inner',
else:
col = col + left_suffix
scol = scol.alias(col)
idx = tuple([idx[0] + left_suffix] + list(idx[1:]))
exprs.append(scol)
for col in right_table.columns:
if col in right_index_columns:
continue
scol = scol_for(right_table, col)
if col in duplicate_columns:
if col in left_keys and col in right_keys:
data_columns.append(col)
column_index.append(idx)
for idx in right._internal.column_index:
col = right._internal.column_name_for(idx)
scol = right_scol_for(idx)
if idx in duplicate_columns:
if idx in left_keys and idx in right_keys: # type: ignore
continue
else:
col = col + right_suffix
scol = scol.alias(col)
idx = tuple([idx[0] + right_suffix] + list(idx[1:]))
exprs.append(scol)
data_columns.append(col)
column_index.append(idx)

left_index_scols = self._internal.index_scols
right_index_scols = right._internal.index_scols

# Retain indices if they are used for joining
if left_index:
if right_index:
exprs.extend(['left_table.`{}`'.format(col) for col in left_index_columns])
exprs.extend(['right_table.`{}`'.format(col) for col in right_index_columns])
index_map = self._internal.index_map + [idx for idx in right._internal.index_map
if idx not in self._internal.index_map]
if how in ('inner', 'left'):
exprs.extend(left_index_scols)
index_map = self._internal.index_map
elif how == 'right':
exprs.extend(right_index_scols)
index_map = right._internal.index_map
else:
index_map = []
for (col, name), left_scol, right_scol in zip(self._internal.index_map,
left_index_scols,
right_index_scols):
scol = F.when(left_scol.isNotNull(), left_scol).otherwise(right_scol)
exprs.append(scol.alias(col))
index_map.append((col, name))
else:
exprs.extend(['right_table.`{}`'.format(col) for col in right_index_columns])
exprs.extend(right_index_scols)
index_map = right._internal.index_map
elif right_index:
exprs.extend(['left_table.`{}`'.format(col) for col in left_index_columns])
exprs.extend(left_index_scols)
index_map = self._internal.index_map
else:
index_map = []

selected_columns = joined_table.select(*exprs)

# Merge left and right indices after the join by replacing missing values in the left index
# with values from the right index, then dropping the redundant right index columns.
if (how == 'right' or how == 'full') and right_index:
for left_index_col, right_index_col in zip(self._internal.index_columns,
right._internal.index_columns):
selected_columns = selected_columns.withColumn(
'left_table.' + left_index_col,
F.when(F.col('left_table.`{}`'.format(left_index_col)).isNotNull(),
F.col('left_table.`{}`'.format(left_index_col)))
.otherwise(F.col('right_table.`{}`'.format(right_index_col)))
).withColumnRenamed(
'left_table.' + left_index_col, left_index_col
).drop(F.col('left_table.`{}`'.format(left_index_col)))
if not (left_index and not right_index):
for right_index_col in right_index_columns:
if right_index_col in left_index_columns:
selected_columns = \
selected_columns.drop(F.col('right_table.`{}`'.format(right_index_col)))

if index_map:
data_columns = [c for c in selected_columns.columns
if c not in [idx[0] for idx in index_map]]
internal = _InternalFrame(
sdf=selected_columns, data_columns=data_columns, index_map=index_map)
return DataFrame(internal)
else:
return DataFrame(selected_columns)
internal = _InternalFrame(sdf=selected_columns,
index_map=index_map if index_map else None,
data_columns=data_columns,
column_index=column_index)
return DataFrame(internal)

def join(self, right: 'DataFrame', on: Optional[Union[str, List[str]]] = None,
how: str = 'left', lsuffix: str = '', rsuffix: str = '') -> 'DataFrame':
Expand Down
4 changes: 2 additions & 2 deletions databricks/koalas/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,8 @@ def _column_index_map(self) -> Dict[Tuple[str, ...], str]:
def column_name_for(self, column_name_or_index: Union[str, Tuple[str, ...]]) -> str:
""" Return the actual Spark column name for the given column name or index. """
if column_name_or_index not in self._column_index_map:
# TODO: assert column_name_or_index not in self.data_columns
assert isinstance(column_name_or_index, str), column_name_or_index
if not isinstance(column_name_or_index, str):
raise KeyError(column_name_or_index)
return column_name_or_index
else:
return self._column_index_map[column_name_or_index]
Expand Down
32 changes: 26 additions & 6 deletions databricks/koalas/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
# limitations under the License.
#

from datetime import date, datetime
import inspect
from datetime import datetime
from distutils.version import LooseVersion
import inspect

import numpy as np
import pandas as pd
from pyspark.sql.utils import AnalysisException

from databricks import koalas as ks
from databricks.koalas.config import set_option, reset_option
Expand Down Expand Up @@ -774,7 +773,7 @@ def check(op, right_kdf=right_kdf, right_pdf=right_pdf):
check(lambda left, right: left.merge(right, left_on='x', right_on='x'),
right_ks, right_ps)
check(lambda left, right: left.set_index('x').merge(right, left_index=True,
right_on='x'), right_ks, right_ps)
right_on='x'), right_ks, right_ps)

# Test join types with Series
for how in ['inner', 'left', 'right', 'outer']:
Expand All @@ -787,6 +786,28 @@ def check(op, right_kdf=right_kdf, right_pdf=right_pdf):
left_index=True, right_index=True),
right_ks, right_ps)

# multi-index columns
left_columns = pd.MultiIndex.from_tuples([('a', 'lkey'), ('a', 'value'), ('b', 'x')])
left_pdf.columns = left_columns
left_kdf.columns = left_columns

right_columns = pd.MultiIndex.from_tuples([('a', 'rkey'), ('a', 'value'), ('c', 'y')])
right_pdf.columns = right_columns
right_kdf.columns = right_columns

check(lambda left, right: left.merge(right))
check(lambda left, right: left.merge(right, on=[('a', 'value')]))
check(lambda left, right: (left.set_index(('a', 'lkey'))
.merge(right.set_index(('a', 'rkey')))))
check(lambda left, right: (left.set_index(('a', 'lkey'))
.merge(right.set_index(('a', 'rkey')),
left_index=True, right_index=True)))
# TODO: when both left_index=True and right_index=True with multi-index columns
# check(lambda left, right: left.merge(right,
# left_on=[('a', 'lkey')], right_on=[('a', 'rkey')]))
# check(lambda left, right: (left.set_index(('a', 'lkey'))
# .merge(right, left_index=True, right_on=[('a', 'rkey')])))

def test_merge_retains_indices(self):
left_pdf = pd.DataFrame({'A': [0, 1]})
right_pdf = pd.DataFrame({'B': [1, 2]}, index=[1, 2])
Expand Down Expand Up @@ -863,8 +884,7 @@ def test_merge_raises(self):
"['inner', 'left', 'right', 'full', 'outer']"):
left.merge(right, left_index=True, right_index=True, how='foo')

with self.assertRaisesRegex(AnalysisException,
'Cannot resolve column name "`id`"'):
with self.assertRaisesRegex(KeyError, 'id'):
left.merge(right, on='id')

def test_append(self):
Expand Down