Skip to content

Commit 1879618

Browse files
authored
DataFrame.reindex(fill_value) does not fill existing NaN values (#1723)
1 parent c82e2a2 commit 1879618

File tree

2 files changed

+54
-18
lines changed

2 files changed

+54
-18
lines changed

databricks/koalas/frame.py

+34-14
Original file line numberDiff line numberDiff line change
@@ -7786,22 +7786,18 @@ def reindex(
77867786
df = self
77877787

77887788
if index is not None:
7789-
df = df._reindex_index(index)
7789+
df = df._reindex_index(index, fill_value)
77907790

77917791
if columns is not None:
7792-
df = df._reindex_columns(columns)
7793-
7794-
# Process missing values.
7795-
if fill_value is not None:
7796-
df = df.fillna(fill_value)
7792+
df = df._reindex_columns(columns, fill_value)
77977793

77987794
# Copy
7799-
if copy:
7795+
if copy and df is self:
78007796
return df.copy()
78017797
else:
78027798
return df
78037799

7804-
def _reindex_index(self, index):
7800+
def _reindex_index(self, index, fill_value):
78057801
# When axis is index, we can mimic pandas' by a right outer join.
78067802
assert (
78077803
len(self._internal.index_spark_column_names) <= 1
@@ -7811,15 +7807,38 @@ def _reindex_index(self, index):
78117807

78127808
kser = ks.Series(list(index))
78137809
labels = kser._internal.spark_frame.select(kser.spark.column.alias(index_column))
7810+
frame = self._internal.resolved_copy.spark_frame.drop(NATURAL_ORDER_COLUMN_NAME)
78147811

7815-
joined_df = self._internal.resolved_copy.spark_frame.drop(NATURAL_ORDER_COLUMN_NAME).join(
7816-
labels, on=index_column, how="right"
7817-
)
7818-
internal = self._internal.with_new_sdf(joined_df)
7812+
if fill_value is not None:
7813+
frame_index_column = verify_temp_column_name(frame, "__frame_index_column__")
7814+
frame = frame.withColumnRenamed(index_column, frame_index_column)
7815+
7816+
temp_fill_value = verify_temp_column_name(frame, "__fill_value__")
7817+
labels = labels.withColumn(temp_fill_value, F.lit(fill_value))
7818+
7819+
frame_index_scol = scol_for(frame, frame_index_column)
7820+
labels_index_scol = scol_for(labels, index_column)
78197821

7822+
joined_df = frame.join(labels, on=[frame_index_scol == labels_index_scol], how="right")
7823+
joined_df = joined_df.select(
7824+
labels_index_scol,
7825+
*[
7826+
F.when(
7827+
frame_index_scol.isNull() & labels_index_scol.isNotNull(),
7828+
scol_for(joined_df, temp_fill_value),
7829+
)
7830+
.otherwise(scol_for(joined_df, col))
7831+
.alias(col)
7832+
for col in self._internal.data_spark_column_names
7833+
]
7834+
)
7835+
else:
7836+
joined_df = frame.join(labels, on=index_column, how="right")
7837+
7838+
internal = self._internal.with_new_sdf(joined_df)
78207839
return DataFrame(internal)
78217840

7822-
def _reindex_columns(self, columns):
7841+
def _reindex_columns(self, columns, fill_value):
78237842
level = self._internal.column_labels_level
78247843
if level > 1:
78257844
label_columns = list(columns)
@@ -7833,12 +7852,13 @@ def _reindex_columns(self, columns):
78337852
raise ValueError(
78347853
"shape (1,{}) doesn't match the shape (1,{})".format(len(col), level)
78357854
)
7855+
fill_value = np.nan if fill_value is None else fill_value
78367856
scols, labels = [], []
78377857
for label in label_columns:
78387858
if label in self._internal.column_labels:
78397859
scols.append(self._internal.spark_column_for(label))
78407860
else:
7841-
scols.append(F.lit(np.nan).alias(name_like_string(label)))
7861+
scols.append(F.lit(fill_value).alias(name_like_string(label)))
78427862
labels.append(label)
78437863

78447864
return DataFrame(self._internal.with_new_columns(scols, column_labels=labels))

databricks/koalas/tests/test_dataframe.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -2367,8 +2367,8 @@ def test_drop_duplicates(self):
23672367

23682368
def test_reindex(self):
23692369
index = ["A", "B", "C", "D", "E"]
2370-
pdf = pd.DataFrame({"numbers": [1.0, 2.0, 3.0, 4.0, 5.0]}, index=index)
2371-
kdf = ks.DataFrame({"numbers": [1.0, 2.0, 3.0, 4.0, 5.0]}, index=index)
2370+
pdf = pd.DataFrame({"numbers": [1.0, 2.0, 3.0, 4.0, None]}, index=index)
2371+
kdf = ks.from_pandas(pdf)
23722372

23732373
self.assert_eq(
23742374
pdf.reindex(["A", "B", "C"], columns=["numbers", "2", "3"]).sort_index(),
@@ -2389,14 +2389,20 @@ def test_reindex(self):
23892389
kdf.reindex(index=["A", "B", "2", "3"]).sort_index(),
23902390
)
23912391

2392+
self.assert_eq(
2393+
pdf.reindex(index=["A", "E", "2", "3"], fill_value=0).sort_index(),
2394+
kdf.reindex(index=["A", "E", "2", "3"], fill_value=0).sort_index(),
2395+
)
2396+
23922397
self.assert_eq(
23932398
pdf.reindex(columns=["numbers"]).sort_index(),
23942399
kdf.reindex(columns=["numbers"]).sort_index(),
23952400
)
23962401

2402+
# Using float as fill_value to avoid int64/32 clash
23972403
self.assert_eq(
2398-
pdf.reindex(columns=["numbers", "2", "3"]).sort_index(),
2399-
kdf.reindex(columns=["numbers", "2", "3"]).sort_index(),
2404+
pdf.reindex(columns=["numbers", "2", "3"], fill_value=0.0).sort_index(),
2405+
kdf.reindex(columns=["numbers", "2", "3"], fill_value=0.0).sort_index(),
24002406
)
24012407

24022408
self.assertRaises(TypeError, lambda: kdf.reindex(columns=["numbers", "2", "3"], axis=1))
@@ -2413,6 +2419,16 @@ def test_reindex(self):
24132419
kdf.reindex(columns=[("X", "numbers"), ("Y", "2"), ("Y", "3")]).sort_index(),
24142420
)
24152421

2422+
# Using float as fill_value to avoid int64/32 clash
2423+
self.assert_eq(
2424+
pdf.reindex(
2425+
columns=[("X", "numbers"), ("Y", "2"), ("Y", "3")], fill_value=0.0
2426+
).sort_index(),
2427+
kdf.reindex(
2428+
columns=[("X", "numbers"), ("Y", "2"), ("Y", "3")], fill_value=0.0
2429+
).sort_index(),
2430+
)
2431+
24162432
self.assertRaises(TypeError, lambda: kdf.reindex(columns=["X"]))
24172433
self.assertRaises(ValueError, lambda: kdf.reindex(columns=[("X",)]))
24182434

0 commit comments

Comments
 (0)