Skip to content

Commit 05f9fcd

Browse files
committed
Cleanup and add some more tests
1 parent e5ee10c commit 05f9fcd

File tree

3 files changed

+124
-28
lines changed

3 files changed

+124
-28
lines changed

databricks/koalas/frame.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,7 @@ def apply_op(kdf, this_columns, that_columns):
400400
for this_column, that_column in zip(this_columns, that_columns):
401401
yield getattr(kdf[this_column], op)(kdf[that_column])
402402

403-
return align_diff_frames(
404-
apply_op, self, other, fillna=True, how="full", include_all_that_columns=False)
403+
return align_diff_frames(apply_op, self, other, fillna=True, how="full")
405404
elif isinstance(other, DataFrame) and self is not other:
406405
# Same DataFrames
407406
for column in self._internal.data_columns:
@@ -6354,9 +6353,9 @@ def __getitem__(self, key):
63546353
def __setitem__(self, key, value):
63556354
from databricks.koalas.series import Series
63566355

6357-
if ((isinstance(value, Series) and value._kdf is not self) or
6358-
(isinstance(value, DataFrame) and value is not self)):
6359-
# Different (anchor) DataFrames
6356+
if (isinstance(value, Series) and value._kdf is not self) or \
6357+
(isinstance(value, DataFrame) and value is not self):
6358+
# Different Series or DataFrames
63606359
if isinstance(value, Series):
63616360
value = value.to_frame()
63626361

@@ -6369,20 +6368,17 @@ def assign_columns(kdf, this_columns, that_columns):
63696368
# that_columns.
63706369
for k, this_column, that_column in zip_longest(key, this_columns, that_columns):
63716370
yield kdf[that_column].rename(k)
6372-
if this_column is not None:
6373-
# if both're same columns first one is higher priority.
6371+
if this_column != k and this_column is not None:
63746372
yield kdf[this_column]
63756373

6376-
kdf = align_diff_frames(
6377-
assign_columns, self, value, fillna=False,
6378-
how="left", include_all_that_columns=True)
6374+
kdf = align_diff_frames(assign_columns, self, value, fillna=False, how="left")
63796375
elif isinstance(key, (tuple, list)):
63806376
assert isinstance(value, DataFrame)
63816377
# Same DataFrames.
63826378
field_names = value.columns
63836379
kdf = self.assign(**{k: value[c] for k, c in zip(key, field_names)})
63846380
else:
6385-
# Same anchor DataFrames.
6381+
# Same Series.
63866382
kdf = self.assign(**{key: value})
63876383

63886384
self._internal = kdf._internal

databricks/koalas/tests/test_ops_on_diff_frames.py

+94
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,23 @@ def pdf4(self):
6262
'f': [2, 2, 2, 2, 2, 2, 2, 2, 2],
6363
}, index=list(range(9)))
6464

65+
@property
66+
def pdf5(self):
67+
return pd.DataFrame({
68+
'a': [1, 2, 3, 4, 5, 6, 7, 8, 9],
69+
'b': [4, 5, 6, 3, 2, 1, 0, 0, 0],
70+
'c': [4, 5, 6, 3, 2, 1, 0, 0, 0],
71+
}, index=[0, 1, 3, 5, 6, 8, 9, 10, 11]).set_index(['a', 'b'])
72+
73+
@property
74+
def pdf6(self):
75+
return pd.DataFrame({
76+
'a': [9, 8, 7, 6, 5, 4, 3, 2, 1],
77+
'b': [0, 0, 0, 4, 5, 6, 1, 2, 3],
78+
'c': [9, 8, 7, 6, 5, 4, 3, 2, 1],
79+
'e': [4, 5, 6, 3, 2, 1, 0, 0, 0],
80+
}, index=list(range(9))).set_index(['a', 'b'])
81+
6582
@property
6683
def kdf1(self):
6784
return ks.from_pandas(self.pdf1)
@@ -78,6 +95,23 @@ def kdf3(self):
7895
def kdf4(self):
7996
return ks.from_pandas(self.pdf4)
8097

98+
@property
99+
def kdf5(self):
100+
return ks.from_pandas(self.pdf5)
101+
102+
@property
103+
def kdf6(self):
104+
return ks.from_pandas(self.pdf6)
105+
106+
def test_no_index(self):
107+
with self.assertRaisesRegex(AssertionError, "cannot join with no overlapping index name"):
108+
ks.range(10) + ks.range(10)
109+
110+
def test_no_matched_index(self):
111+
with self.assertRaisesRegex(AssertionError, "cannot join with no overlapping index name"):
112+
ks.DataFrame({'a': [1, 2, 3]}).set_index('a') + \
113+
ks.DataFrame({'b': [1, 2, 3]}).set_index('b')
114+
81115
def test_arithmetic(self):
82116
# Series
83117
self.assertEqual(
@@ -196,6 +230,66 @@ def test_assignment_frame_chain(self):
196230

197231
self.assert_eq(kdf.sort_index(), pdf.sort_index())
198232

233+
def test_multi_index_arithmetic(self):
234+
# Series
235+
self.assertEqual(
236+
repr((self.kdf5.c - self.kdf6.e).sort_index()),
237+
repr((self.pdf5.c - self.pdf6.e).rename("c").sort_index()))
238+
239+
self.assertEqual(
240+
repr((self.kdf5["c"] / self.kdf6["e"]).sort_index()),
241+
repr((self.pdf5["c"] / self.pdf6["e"]).rename("c").sort_index()))
242+
243+
# DataFrame
244+
self.assert_eq(
245+
repr((self.kdf5 + self.kdf6).sort_index()),
246+
repr((self.pdf5 + self.pdf6).sort_index()))
247+
248+
def test_multi_index_assignment_series(self):
249+
kdf = ks.from_pandas(self.pdf5)
250+
pdf = self.pdf5.copy()
251+
kdf['x'] = self.kdf6.e
252+
pdf['x'] = self.pdf6.e
253+
254+
self.assert_eq(kdf.sort_index(), pdf.sort_index())
255+
256+
kdf = ks.from_pandas(self.pdf5)
257+
pdf = self.pdf5.copy()
258+
kdf['x'] = self.kdf6.e
259+
pdf['x'] = self.pdf6.e
260+
261+
self.assert_eq(kdf.sort_index(), pdf.sort_index())
262+
263+
kdf = ks.from_pandas(self.pdf5)
264+
pdf = self.pdf5.copy()
265+
kdf['c'] = self.kdf6.e
266+
267+
pdf['c'] = self.pdf6.e
268+
269+
self.assert_eq(kdf.sort_index(), pdf.sort_index())
270+
271+
def test_multi_index_assignment_frame(self):
272+
kdf = ks.from_pandas(self.pdf5)
273+
pdf = self.pdf5.copy()
274+
kdf[['c']] = self.kdf5
275+
pdf[['c']] = self.pdf5
276+
277+
self.assert_eq(kdf.sort_index(), pdf.sort_index())
278+
279+
kdf = ks.from_pandas(self.pdf5)
280+
pdf = self.pdf5.copy()
281+
kdf[['x']] = self.kdf5
282+
pdf[['x']] = self.pdf5
283+
284+
self.assert_eq(kdf.sort_index(), pdf.sort_index())
285+
286+
kdf = ks.from_pandas(self.pdf6)
287+
pdf = self.pdf6.copy()
288+
kdf[['x', 'y']] = self.kdf6
289+
pdf[['x', 'y']] = self.pdf6
290+
291+
self.assert_eq(kdf.sort_index(), pdf.sort_index())
292+
199293

200294
class OpsOnDiffFramesDisabledTest(ReusedSQLTestCase, SQLTestUtils):
201295

databricks/koalas/utils.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def combine_frames(this, *args, how="full"):
3535
This method combines `this` DataFrame with a different `that` DataFrame or
3636
Series from a different DataFrame.
3737
38-
It returns a dataframe that has prefix, `this_` and `that_` to distinct
39-
the columns names.
38+
It returns a DataFrame that has prefix `this_` and `that_` to distinct
39+
the columns names from both DataFrames
4040
4141
It internally performs a join operation which can be expensive in general.
4242
So, if `OPS_ON_DIFF_FRAMES` environment variable is not set,
@@ -70,11 +70,12 @@ def combine_frames(this, *args, how="full"):
7070
join_scols = []
7171
merged_index_scols = []
7272

73+
# If the same named index is found, that's used.
7374
for this_column, this_name in this_index_map:
7475
for that_col, that_name in that_index_map:
7576
if this_name == that_name:
76-
# We should map the actual Spark columns even if
77-
# the index names are the name.
77+
# We should merge the Spark columns into one
78+
# to mimic pandas' behavior.
7879
this_scol = this._internal.scol_for(this_column)
7980
that_scol = that._internal.scol_for(that_col)
8081
join_scol = this_scol == that_scol
@@ -85,7 +86,7 @@ def combine_frames(this, *args, how="full"):
8586
).otherwise(that_scol).alias(this_column))
8687
break
8788
else:
88-
raise ValueError("Index names must be matched.")
89+
raise ValueError("Index names must be exactly matched currently.")
8990

9091
assert len(join_scols) > 0, "cannot join with no overlapping index names"
9192

@@ -106,13 +107,13 @@ def combine_frames(this, *args, how="full"):
106107
"it comes from a different dataframe")
107108

108109

109-
def align_diff_frames(func, this, that, fillna=True, how="full", include_all_that_columns=False):
110+
def align_diff_frames(resolve_func, this, that, fillna=True, how="full"):
110111
"""
111112
This method aligns two different DataFrames with a given `func`. Columns are resolved and
112113
handled within the given `func`.
113114
To use this, `OPS_ON_DIFF_FRAMES` environment variable should be enabled, for now.
114115
115-
:param func: Takes aligned (joined) DataFrame, the column of the current DataFrame, and
116+
:param resolve_func: Takes aligned (joined) DataFrame, the column of the current DataFrame, and
116117
the column of another DataFrame. It returns an iterable that produces Series.
117118
118119
>>> import os
@@ -152,15 +153,19 @@ def align_diff_frames(func, this, that, fillna=True, how="full", include_all_tha
152153
:param that: another DataFrame to align
153154
:param fillna: If True, it fills missing values in non-common columns in both `this` and `that`.
154155
Otherwise, it returns as are.
155-
:param how: join way.
156-
:param include_all_that_columns: If True, all non-common columns from `that` are added into
157-
`that_columns` into `func`, and they are excluded in non-common columns group
158-
(controlled by `fillna`). Otherwise, `this_columns` and `that_columns` will always in the
159-
same common column group.
156+
:param how: join way. In addition, it affects how `resolve_func` resolves the column conflict.
157+
- full: `resolve_func` should resolve only common columns from 'this' and 'that' DataFrames.
158+
For instance, if 'this' has columns A, B, C and that has B, C, D, `this_columns` and
159+
'that_columns' in this function are B, C and B, C.
160+
- left: `resolve_func` should resolve columns including that columns.
161+
For instance, if 'this' has columns A, B, C and that has B, C, D, `this_columns` is
162+
B, C but `that_columns` are B, C, D.
160163
:return: Alined DataFrame
161164
"""
162165
from databricks.koalas import DataFrame
163166

167+
assert how == "full" or how == "left"
168+
164169
this_data_columns = this._internal.data_columns
165170
that_data_columns = that._internal.data_columns
166171
common_columns = set(this_data_columns).intersection(that_data_columns)
@@ -185,11 +190,12 @@ def align_diff_frames(func, this, that, fillna=True, how="full", include_all_tha
185190
that_columns_to_apply.append(combined_column)
186191
break
187192
else:
188-
if include_all_that_columns and \
193+
if how == "left" and \
189194
combined_column in ["__that_%s" % c for c in that_data_columns]:
190-
# In this case, we will drop that columns in columns to keep but passes it later
191-
# to `func`. Note that adding this into a separate list is intentional so that
192-
# `this_columns` and `that_columns` can be paired.
195+
# In this case, we will drop `that_columns` in `columns_to_keep` but passes
196+
# it later to `func`. `func` should resolve it.
197+
# Note that adding this into a separate list (`additional_that_columns`)
198+
# is intentional so that `this_columns` and `that_columns` can be paired.
193199
additional_that_columns.append(combined_column)
194200
elif fillna:
195201
columns_to_keep.append(F.lit(None).cast(FloatType()).alias(combined_column))
@@ -200,7 +206,7 @@ def align_diff_frames(func, this, that, fillna=True, how="full", include_all_tha
200206

201207
# Should extract columns to apply and do it in a batch in case
202208
# it adds new columns for example.
203-
kser_set = list(func(combined, this_columns_to_apply, that_columns_to_apply))
209+
kser_set = list(resolve_func(combined, this_columns_to_apply, that_columns_to_apply))
204210
columns_applied = [c._scol for c in kser_set]
205211

206212
sdf = combined._sdf.select(

0 commit comments

Comments
 (0)