From c7a1ad7156be6019e22bca977e28a5326fdd39ec Mon Sep 17 00:00:00 2001 From: itholic Date: Fri, 19 Feb 2021 20:03:31 +0900 Subject: [PATCH 1/4] Fix DataFrame.merge to work properly --- databricks/koalas/frame.py | 40 +++++++++++++++++-- databricks/koalas/tests/test_dataframe.py | 48 +++++++++++++++++++++++ 2 files changed, 85 insertions(+), 3 deletions(-) diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py index 492599d5d3..f707ce02db 100644 --- a/databricks/koalas/frame.py +++ b/databricks/koalas/frame.py @@ -7260,10 +7260,37 @@ def to_list(os: Optional[Union[Any, List[Any], Tuple, List[Tuple]]]) -> List[Tup if len(left_key_names) != len(right_key_names): raise ValueError("len(left_keys) must equal len(right_keys)") + # We should distinguish the name to avoid ambiguous column name after merging. + right_prefix = "__right_" + right_key_names = [right_prefix + right_key_name for right_key_name in right_key_names] + how = validate_how(how) + def resolve(internal, side): + rename = lambda col: "__{}_{}".format(side, col) + internal = internal.resolved_copy + sdf = internal.spark_frame + sdf = internal.spark_frame.select( + [ + scol_for(sdf, col).alias(rename(col)) + for col in sdf.columns + if col not in HIDDEN_COLUMNS + ] + + list(HIDDEN_COLUMNS) + ) + return internal.copy( + spark_frame=sdf, + index_spark_columns=[ + scol_for(sdf, rename(col)) for col in internal.index_spark_column_names + ], + data_spark_columns=[ + scol_for(sdf, rename(col)) for col in internal.data_spark_column_names + ], + preserve_dtypes=True, + ) + left_internal = self._internal.resolved_copy - right_internal = right._internal.resolved_copy + right_internal = resolve(right._internal, "right") left_table = left_internal.spark_frame.alias("left_table") right_table = right_internal.spark_frame.alias("right_table") @@ -7301,9 +7328,13 @@ def to_list(os: Optional[Union[Any, List[Any], Tuple, List[Tuple]]]) -> List[Tup scol = left_scol_for(label) if label in duplicate_columns: spark_column_name = left_internal.spark_column_name_for(label) - if spark_column_name in left_key_names and spark_column_name in right_key_names: + if ( + spark_column_name in left_key_names + and (right_prefix + spark_column_name) in right_key_names + ): right_scol = right_scol_for(label) if how == "right": + col = right_prefix + col scol = right_scol elif how == "full": scol = F.when(scol.isNotNull(), scol).otherwise(right_scol).alias(col) @@ -7321,7 +7352,10 @@ def to_list(os: Optional[Union[Any, List[Any], Tuple, List[Tuple]]]) -> List[Tup scol = right_scol_for(label) if label in duplicate_columns: spark_column_name = left_internal.spark_column_name_for(label) - if spark_column_name in left_key_names and spark_column_name in right_key_names: + if ( + spark_column_name in left_key_names + and (right_prefix + spark_column_name) in right_key_names + ): continue else: col = col + right_suffix diff --git a/databricks/koalas/tests/test_dataframe.py b/databricks/koalas/tests/test_dataframe.py index 8dfa0d9fc4..1555781cb0 100644 --- a/databricks/koalas/tests/test_dataframe.py +++ b/databricks/koalas/tests/test_dataframe.py @@ -2030,6 +2030,54 @@ def check(op, right_kdf=right_kdf, right_pdf=right_pdf): ) ) + def test_merge_same_anchor(self): + pdf = pd.DataFrame( + { + "lkey": ["foo", "bar", "baz", "foo", "bar", "l"], + "rkey": ["baz", "foo", "bar", "baz", "foo", "r"], + "value": [1, 1, 3, 5, 6, 7], + "x": list("abcdef"), + "y": list("efghij"), + }, + columns=["lkey", "rkey", "value", "x", "y"], + ) + kdf = ks.from_pandas(pdf) + + left_pdf = pdf[["lkey", "value", "x"]] + right_pdf = pdf[["rkey", "value", "y"]] + left_kdf = kdf[["lkey", "value", "x"]] + right_kdf = kdf[["rkey", "value", "y"]] + + def check(op, right_kdf=right_kdf, right_pdf=right_pdf): + k_res = op(left_kdf, right_kdf) + k_res = k_res.to_pandas() + k_res = k_res.sort_values(by=list(k_res.columns)) + k_res = k_res.reset_index(drop=True) + p_res = op(left_pdf, right_pdf) + p_res = p_res.sort_values(by=list(p_res.columns)) + p_res = p_res.reset_index(drop=True) + self.assert_eq(k_res, p_res) + + check(lambda left, right: left.merge(right)) + check(lambda left, right: left.merge(right, on="value")) + check(lambda left, right: left.merge(right, left_on="lkey", right_on="rkey")) + check(lambda left, right: left.set_index("lkey").merge(right.set_index("rkey"))) + check( + lambda left, right: left.set_index("lkey").merge( + right, left_index=True, right_on="rkey" + ) + ) + check( + lambda left, right: left.merge( + right.set_index("rkey"), left_on="lkey", right_index=True + ) + ) + check( + lambda left, right: left.set_index("lkey").merge( + right.set_index("rkey"), left_index=True, right_index=True + ) + ) + def test_merge_retains_indices(self): left_pdf = pd.DataFrame({"A": [0, 1]}) right_pdf = pd.DataFrame({"B": [1, 2]}, index=[1, 2]) From b2609427efa38243a9b10b55eeebac764317b687 Mon Sep 17 00:00:00 2001 From: itholic Date: Fri, 19 Feb 2021 21:58:24 +0900 Subject: [PATCH 2/4] preserve the Spark column name --- databricks/koalas/frame.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py index f707ce02db..944929b649 100644 --- a/databricks/koalas/frame.py +++ b/databricks/koalas/frame.py @@ -7358,7 +7358,8 @@ def resolve(internal, side): ): continue else: - col = col + right_suffix + # remove `right_prefix` here. + col = (col + right_suffix)[len(right_prefix) :] scol = scol.alias(col) label = tuple([str(label[0]) + right_suffix] + list(label[1:])) exprs.append(scol) From 3a051f93649e4d2a92e34bf89d6df99d4e164f8e Mon Sep 17 00:00:00 2001 From: itholic Date: Sun, 21 Feb 2021 18:37:02 +0900 Subject: [PATCH 3/4] Addressed comments --- databricks/koalas/frame.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py index 944929b649..10d40ae519 100644 --- a/databricks/koalas/frame.py +++ b/databricks/koalas/frame.py @@ -7334,8 +7334,7 @@ def resolve(internal, side): ): right_scol = right_scol_for(label) if how == "right": - col = right_prefix + col - scol = right_scol + scol = right_scol.alias(col) elif how == "full": scol = F.when(scol.isNotNull(), scol).otherwise(right_scol).alias(col) else: @@ -7348,8 +7347,9 @@ def resolve(internal, side): data_columns.append(col) column_labels.append(label) for label in right_internal.column_labels: - col = right_internal.spark_column_name_for(label) - scol = right_scol_for(label) + # recover `right_prefix` here. + col = right_internal.spark_column_name_for(label)[len(right_prefix) :] + scol = right_scol_for(label).alias(col) if label in duplicate_columns: spark_column_name = left_internal.spark_column_name_for(label) if ( @@ -7358,8 +7358,7 @@ def resolve(internal, side): ): continue else: - # remove `right_prefix` here. - col = (col + right_suffix)[len(right_prefix) :] + col = col + right_suffix scol = scol.alias(col) label = tuple([str(label[0]) + right_suffix] + list(label[1:])) exprs.append(scol) From 16c664d219faddd876a32a2e9abbcb39aa1b3834 Mon Sep 17 00:00:00 2001 From: itholic Date: Mon, 22 Feb 2021 10:30:04 +0900 Subject: [PATCH 4/4] Retrigger the test