Skip to content

Commit cf22a5d

Browse files
authored
Fix update and add tests for update and join. (#848)
1 parent e52e70f commit cf22a5d

File tree

2 files changed

+67
-6
lines changed

2 files changed

+67
-6
lines changed

databricks/koalas/frame.py

+7 -5
Original file line number | Diff line number | Diff line change
@@ -5379,7 +5379,8 @@ def merge(self, right: 'DataFrame', how: str = 'inner',
53795379
column_index=column_index)
53805380
return DataFrame(internal)
53815381

5382-
def join(self, right: 'DataFrame', on: Optional[Union[str, List[str]]] = None,
5382+
def join(self, right: 'DataFrame',
5383+
on: Optional[Union[str, List[str], Tuple[str, ...], List[Tuple[str, ...]]]] = None,
53835384
how: str = 'left', lsuffix: str = '', rsuffix: str = '') -> 'DataFrame':
53845385
"""
53855386
Join columns of another DataFrame.
@@ -5633,13 +5634,14 @@ def update(self, other: 'DataFrame', join: str = 'left', overwrite: bool = True)
56335634
if isinstance(other, ks.Series):
56345635
other = DataFrame(other)
56355636

5636-
update_columns = list(set(self._internal.data_columns)
5637-
.intersection(set(other._internal.data_columns)))
5637+
update_columns = list(set(self._internal.column_index)
5638+
.intersection(set(other._internal.column_index)))
56385639
update_sdf = self.join(other[update_columns], rsuffix='_new')._sdf
56395640

5640-
for column_name in update_columns:
5641+
for column_index in update_columns:
5642+
column_name = self._internal.column_name_for(column_index)
56415643
old_col = scol_for(update_sdf, column_name)
5642-
new_col = scol_for(update_sdf, column_name + '_new')
5644+
new_col = scol_for(update_sdf, other._internal.column_name_for(column_index) + '_new')
56435645
if overwrite:
56445646
update_sdf = update_sdf.withColumn(column_name, F.when(new_col.isNull(), old_col)
56455647
.otherwise(new_col))

databricks/koalas/tests/test_dataframe.py

+60 -1
Original file line number | Diff line number | Diff line change
@@ -1143,6 +1143,33 @@ def test_join(self):
11431143
join_kdf.sort_values(by=list(join_kdf.columns), inplace=True)
11441144
self.assert_eq(join_pdf.reset_index(drop=True), join_kdf.reset_index(drop=True))
11451145

1146+
# multi-index columns
1147+
columns1 = pd.MultiIndex.from_tuples([('x', 'key'), ('Y', 'A')])
1148+
columns2 = pd.MultiIndex.from_tuples([('x', 'key'), ('Y', 'B')])
1149+
pdf1.columns = columns1
1150+
pdf2.columns = columns2
1151+
kdf1.columns = columns1
1152+
kdf2.columns = columns2
1153+
1154+
join_pdf = pdf1.join(pdf2, lsuffix='_left', rsuffix='_right')
1155+
join_pdf.sort_values(by=list(join_pdf.columns), inplace=True)
1156+
1157+
join_kdf = kdf1.join(kdf2, lsuffix='_left', rsuffix='_right')
1158+
join_kdf.sort_values(by=list(join_kdf.columns), inplace=True)
1159+
1160+
self.assert_eq(join_pdf, join_kdf)
1161+
1162+
# check `on` parameter
1163+
join_pdf = pdf1.join(pdf2.set_index(('x', 'key')), on=[('x', 'key')],
1164+
lsuffix='_left', rsuffix='_right')
1165+
join_pdf.sort_values(by=list(join_pdf.columns), inplace=True)
1166+
1167+
join_kdf = kdf1.join(kdf2.set_index(('x', 'key')), on=[('x', 'key')],
1168+
lsuffix='_left', rsuffix='_right')
1169+
join_kdf.sort_values(by=list(join_kdf.columns), inplace=True)
1170+
1171+
self.assert_eq(join_pdf.reset_index(drop=True), join_kdf.reset_index(drop=True))
1172+
11461173
def test_replace(self):
11471174
pdf = pd.DataFrame({"name": ['Ironman', 'Captain America', 'Thor', 'Hulk'],
11481175
"weapon": ['Mark-45', 'Shield', 'Mjolnir', 'Smash']})
@@ -1195,7 +1222,7 @@ def test_replace(self):
11951222

11961223
def test_update(self):
11971224
# check base function
1198-
def get_data():
1225+
def get_data(left_columns=None, right_columns=None):
11991226
left_pdf = pd.DataFrame({'A': ['1', '2', '3', '4'],
12001227
'B': ['100', '200', np.nan, np.nan]},
12011228
columns=['A', 'B'])
@@ -1206,6 +1233,12 @@ def get_data():
12061233
columns=['A', 'B'])
12071234
right_kdf = ks.DataFrame({'B': ['x', None, 'y', None],
12081235
'C': ['100', '200', '300', '400']}, columns=['B', 'C'])
1236+
if left_columns is not None:
1237+
left_pdf.columns = left_columns
1238+
left_kdf.columns = left_columns
1239+
if right_columns is not None:
1240+
right_pdf.columns = right_columns
1241+
right_kdf.columns = right_columns
12091242
return left_kdf, left_pdf, right_kdf, right_pdf
12101243

12111244
left_kdf, left_pdf, right_kdf, right_pdf = get_data()
@@ -1221,6 +1254,32 @@ def get_data():
12211254
with self.assertRaises(NotImplementedError):
12221255
left_kdf.update(right_kdf, join='right')
12231256

1257+
# multi-index columns
1258+
left_columns = pd.MultiIndex.from_tuples([('X', 'A'), ('X', 'B')])
1259+
right_columns = pd.MultiIndex.from_tuples([('X', 'B'), ('Y', 'C')])
1260+
1261+
left_kdf, left_pdf, right_kdf, right_pdf = get_data(left_columns=left_columns,
1262+
right_columns=right_columns)
1263+
left_pdf.update(right_pdf)
1264+
left_kdf.update(right_kdf)
1265+
self.assert_eq(left_pdf.sort_values(by=[('X', 'A'), ('X', 'B')]),
1266+
left_kdf.sort_values(by=[('X', 'A'), ('X', 'B')]))
1267+
1268+
left_kdf, left_pdf, right_kdf, right_pdf = get_data(left_columns=left_columns,
1269+
right_columns=right_columns)
1270+
left_pdf.update(right_pdf, overwrite=False)
1271+
left_kdf.update(right_kdf, overwrite=False)
1272+
self.assert_eq(left_pdf.sort_values(by=[('X', 'A'), ('X', 'B')]),
1273+
left_kdf.sort_values(by=[('X', 'A'), ('X', 'B')]))
1274+
1275+
right_columns = pd.MultiIndex.from_tuples([('Y', 'B'), ('Y', 'C')])
1276+
left_kdf, left_pdf, right_kdf, right_pdf = get_data(left_columns=left_columns,
1277+
right_columns=right_columns)
1278+
left_pdf.update(right_pdf)
1279+
left_kdf.update(right_kdf)
1280+
self.assert_eq(left_pdf.sort_values(by=[('X', 'A'), ('X', 'B')]),
1281+
left_kdf.sort_values(by=[('X', 'A'), ('X', 'B')]))
1282+
12241283
def test_pivot_table_dtypes(self):
12251284
pdf = pd.DataFrame({'a': [4, 2, 3, 4, 8, 6],
12261285
'b': [1, 2, 2, 4, 2, 4],

0 commit comments

Comments
 (0)