From e2d4c4b0869115463039d5728f5a512cc9fdd2da Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 20 Sep 2019 17:46:42 -0700 Subject: [PATCH 1/2] Fix merge to partially support multi-index columns. --- databricks/koalas/frame.py | 123 +++++++++++----------- databricks/koalas/internal.py | 4 +- databricks/koalas/tests/test_dataframe.py | 25 ++++- 3 files changed, 87 insertions(+), 65 deletions(-) diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py index c732803df8..7783d305a7 100644 --- a/databricks/koalas/frame.py +++ b/databricks/koalas/frame.py @@ -5077,9 +5077,9 @@ def shape(self): # TODO: support multi-index columns def merge(self, right: 'DataFrame', how: str = 'inner', - on: Optional[Union[str, List[str]]] = None, - left_on: Optional[Union[str, List[str]]] = None, - right_on: Optional[Union[str, List[str]]] = None, + on: Union[str, List[str], Tuple[str, ...], List[Tuple[str, ...]]] = None, + left_on: Union[str, List[str], Tuple[str, ...], List[Tuple[str, ...]]] = None, + right_on: Union[str, List[str], Tuple[str, ...], List[Tuple[str, ...]]] = None, left_index: bool = False, right_index: bool = False, suffixes: Tuple[str, str] = ('_x', '_y')) -> 'DataFrame': """ @@ -5191,7 +5191,11 @@ def merge(self, right: 'DataFrame', how: str = 'inner', As described in #263, joining string columns currently returns None for missing values instead of NaN. """ - _to_list = lambda o: o if o is None or is_list_like(o) else [o] + _to_list = lambda os: (os if os is None + else [os] if isinstance(os, tuple) + else [(os,)] if isinstance(os, str) + else [o if isinstance(o, tuple) else (o,) # type: ignore + for o in os]) if on: if left_on or right_on: @@ -5220,8 +5224,8 @@ def merge(self, right: 'DataFrame', how: str = 'inner', raise ValueError( 'No common columns to perform merge on. Merge options: ' 'left_on=None, right_on=None, left_index=False, right_index=False') - left_keys = common - right_keys = common + left_keys = _to_list(common) + right_keys = _to_list(common) if len(left_keys) != len(right_keys): # type: ignore raise ValueError('len(left_keys) must equal len(right_keys)') @@ -5235,11 +5239,14 @@ def merge(self, right: 'DataFrame', how: str = 'inner', raise ValueError("The 'how' parameter has to be amongst the following values: ", "['inner', 'left', 'right', 'outer']") - left_table = self._internal.spark_internal_df.alias('left_table') - right_table = right._internal.spark_internal_df.alias('right_table') + left_table = self._sdf.alias('left_table') + right_table = right._sdf.alias('right_table') - left_key_columns = [scol_for(left_table, col) for col in left_keys] # type: ignore - right_key_columns = [scol_for(right_table, col) for col in right_keys] # type: ignore + left_scol_for = lambda idx: scol_for(left_table, self._internal.column_name_for(idx)) + right_scol_for = lambda idx: scol_for(right_table, right._internal.column_name_for(idx)) + + left_key_columns = [left_scol_for(idx) for idx in left_keys] # type: ignore + right_key_columns = [right_scol_for(idx) for idx in right_keys] # type: ignore join_condition = reduce(lambda x, y: x & y, [lkey == rkey for lkey, rkey @@ -5252,20 +5259,18 @@ def merge(self, right: 'DataFrame', how: str = 'inner', right_suffix = suffixes[1] # Append suffixes to columns with the same name to avoid conflicts later - duplicate_columns = (set(self._internal.data_columns) - & set(right._internal.data_columns)) - - left_index_columns = set(self._internal.index_columns) - right_index_columns = set(right._internal.index_columns) + duplicate_columns = (set(self._internal.column_index) + & set(right._internal.column_index)) exprs = [] - for col in left_table.columns: - if col in left_index_columns: - continue - scol = scol_for(left_table, col) - if col in duplicate_columns: - if col in left_keys and col in right_keys: - right_scol = scol_for(right_table, col) + data_columns = [] + column_index = [] + for idx in self._internal.column_index: + col = self._internal.column_name_for(idx) + scol = left_scol_for(idx) + if idx in duplicate_columns: + if idx in left_keys and idx in right_keys: # type: ignore + right_scol = right_scol_for(idx) if how == 'right': scol = right_scol elif how == 'full': @@ -5275,64 +5280,60 @@ def merge(self, right: 'DataFrame', how: str = 'inner', else: col = col + left_suffix scol = scol.alias(col) + idx = tuple([idx[0] + left_suffix] + list(idx[1:])) exprs.append(scol) - for col in right_table.columns: - if col in right_index_columns: - continue - scol = scol_for(right_table, col) - if col in duplicate_columns: - if col in left_keys and col in right_keys: + data_columns.append(col) + column_index.append(idx) + for idx in right._internal.column_index: + col = right._internal.column_name_for(idx) + scol = right_scol_for(idx) + if idx in duplicate_columns: + if idx in left_keys and idx in right_keys: # type: ignore continue else: col = col + right_suffix scol = scol.alias(col) + idx = tuple([idx[0] + right_suffix] + list(idx[1:])) exprs.append(scol) + data_columns.append(col) + column_index.append(idx) + + left_index_scols = self._internal.index_scols + right_index_scols = right._internal.index_scols # Retain indices if they are used for joining if left_index: if right_index: - exprs.extend(['left_table.`{}`'.format(col) for col in left_index_columns]) - exprs.extend(['right_table.`{}`'.format(col) for col in right_index_columns]) - index_map = self._internal.index_map + [idx for idx in right._internal.index_map - if idx not in self._internal.index_map] + if how in ('inner', 'left'): + exprs.extend(left_index_scols) + index_map = self._internal.index_map + elif how == 'right': + exprs.extend(right_index_scols) + index_map = right._internal.index_map + else: + index_map = [] + for (col, name), left_scol, right_scol in zip(self._internal.index_map, + left_index_scols, + right_index_scols): + scol = F.when(left_scol.isNotNull(), left_scol).otherwise(right_scol) + exprs.append(scol.alias(col)) + index_map.append((col, name)) else: - exprs.extend(['right_table.`{}`'.format(col) for col in right_index_columns]) + exprs.extend(right_index_scols) index_map = right._internal.index_map elif right_index: - exprs.extend(['left_table.`{}`'.format(col) for col in left_index_columns]) + exprs.extend(left_index_scols) index_map = self._internal.index_map else: index_map = [] selected_columns = joined_table.select(*exprs) - # Merge left and right indices after the join by replacing missing values in the left index - # with values from the right index and dropping - if (how == 'right' or how == 'full') and right_index: - for left_index_col, right_index_col in zip(self._internal.index_columns, - right._internal.index_columns): - selected_columns = selected_columns.withColumn( - 'left_table.' + left_index_col, - F.when(F.col('left_table.`{}`'.format(left_index_col)).isNotNull(), - F.col('left_table.`{}`'.format(left_index_col))) - .otherwise(F.col('right_table.`{}`'.format(right_index_col))) - ).withColumnRenamed( - 'left_table.' + left_index_col, left_index_col - ).drop(F.col('left_table.`{}`'.format(left_index_col))) - if not (left_index and not right_index): - for right_index_col in right_index_columns: - if right_index_col in left_index_columns: - selected_columns = \ - selected_columns.drop(F.col('right_table.`{}`'.format(right_index_col))) - - if index_map: - data_columns = [c for c in selected_columns.columns - if c not in [idx[0] for idx in index_map]] - internal = _InternalFrame( - sdf=selected_columns, data_columns=data_columns, index_map=index_map) - return DataFrame(internal) - else: - return DataFrame(selected_columns) + internal = _InternalFrame(sdf=selected_columns, + index_map=index_map if index_map else None, + data_columns=data_columns, + column_index=column_index) + return DataFrame(internal) def join(self, right: 'DataFrame', on: Optional[Union[str, List[str]]] = None, how: str = 'left', lsuffix: str = '', rsuffix: str = '') -> 'DataFrame': diff --git a/databricks/koalas/internal.py b/databricks/koalas/internal.py index 500a1d61a4..3df616c6f1 100644 --- a/databricks/koalas/internal.py +++ b/databricks/koalas/internal.py @@ -513,8 +513,8 @@ def _column_index_map(self) -> Dict[Tuple[str, ...], str]: def column_name_for(self, column_name_or_index: Union[str, Tuple[str, ...]]) -> str: """ Return the actual Spark column name for the given column name or index. """ if column_name_or_index not in self._column_index_map: - # TODO: assert column_name_or_index not in self.data_columns - assert isinstance(column_name_or_index, str), column_name_or_index + if not isinstance(column_name_or_index, str): + raise KeyError(column_name_or_index) return column_name_or_index else: return self._column_index_map[column_name_or_index] diff --git a/databricks/koalas/tests/test_dataframe.py b/databricks/koalas/tests/test_dataframe.py index 7e415e7a14..f7f07c7c77 100644 --- a/databricks/koalas/tests/test_dataframe.py +++ b/databricks/koalas/tests/test_dataframe.py @@ -764,6 +764,28 @@ def check(op): check(lambda left, right: left.merge(right, left_on='lkey', right_on='rkey', suffixes=['_left', '_right'])) + # multi-index columns + left_columns = pd.MultiIndex.from_tuples([('a', 'lkey'), ('a', 'value'), ('b', 'x')]) + left_pdf.columns = left_columns + left_kdf.columns = left_columns + + right_columns = pd.MultiIndex.from_tuples([('a', 'rkey'), ('a', 'value'), ('c', 'y')]) + right_pdf.columns = right_columns + right_kdf.columns = right_columns + + check(lambda left, right: left.merge(right)) + check(lambda left, right: left.merge(right, on=[('a', 'value')])) + check(lambda left, right: (left.set_index(('a', 'lkey')) + .merge(right.set_index(('a', 'rkey'))))) + check(lambda left, right: (left.set_index(('a', 'lkey')) + .merge(right.set_index(('a', 'rkey')), + left_index=True, right_index=True))) + # TODO: when both left_index=True and right_index=True with multi-index columns + # check(lambda left, right: left.merge(right, + # left_on=[('a', 'lkey')], right_on=[('a', 'rkey')])) + # check(lambda left, right: (left.set_index(('a', 'lkey')) + # .merge(right, left_index=True, right_on=[('a', 'rkey')]))) + def test_merge_retains_indices(self): left_pdf = pd.DataFrame({'A': [0, 1]}) right_pdf = pd.DataFrame({'B': [1, 2]}, index=[1, 2]) @@ -840,8 +862,7 @@ def test_merge_raises(self): "['inner', 'left', 'right', 'full', 'outer']"): left.merge(right, left_index=True, right_index=True, how='foo') - with self.assertRaisesRegex(AnalysisException, - 'Cannot resolve column name "`id`"'): + with self.assertRaisesRegex(KeyError, 'id'): left.merge(right, on='id') def test_append(self): From 2fb33082c64c814c96378f8e4b5e24e7c68bb2df Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 24 Sep 2019 15:14:24 -0700 Subject: [PATCH 2/2] Fix. --- databricks/koalas/frame.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py index 2621f472c1..5d1edd846b 100644 --- a/databricks/koalas/frame.py +++ b/databricks/koalas/frame.py @@ -5194,6 +5194,9 @@ def merge(self, right: 'DataFrame', how: str = 'inner', else [o if isinstance(o, tuple) else (o,) # type: ignore for o in os]) + if isinstance(right, ks.Series): + right = right.to_frame() + if on: if left_on or right_on: raise ValueError('Can only pass argument "on" OR "left_on" and "right_on", ' @@ -5216,10 +5219,7 @@ def merge(self, right: 'DataFrame', how: str = 'inner', if right_keys and not left_keys: raise ValueError('Must pass left_on or left_index=True') if not left_keys and not right_keys: - if isinstance(right, ks.Series): - common = list(self.columns.intersection([right.name])) - else: - common = list(self.columns.intersection(right.columns)) + common = list(self.columns.intersection(right.columns)) if len(common) == 0: raise ValueError( 'No common columns to perform merge on. Merge options: '