Skip to content

Commit

Permalink
DataFrame.reindex(fill_value) does not fill existing NaN values
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasG0 committed Aug 25, 2020
1 parent a3394fd commit 353b55a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 18 deletions.
48 changes: 34 additions & 14 deletions databricks/koalas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7794,22 +7794,18 @@ def reindex(
df = self

if index is not None:
df = df._reindex_index(index)
df = df._reindex_index(index, fill_value)

if columns is not None:
df = df._reindex_columns(columns)

# Process missing values.
if fill_value is not None:
df = df.fillna(fill_value)
df = df._reindex_columns(columns, fill_value)

# Copy
if copy:
if copy and df is self:
return df.copy()
else:
return df

def _reindex_index(self, index):
def _reindex_index(self, index, fill_value):
# When axis is index, we can mimic pandas' behavior with a right outer join.
assert (
len(self._internal.index_spark_column_names) <= 1
Expand All @@ -7819,15 +7815,38 @@ def _reindex_index(self, index):

kser = ks.Series(list(index))
labels = kser._internal.spark_frame.select(kser.spark.column.alias(index_column))
frame = self._internal.resolved_copy.spark_frame.drop(NATURAL_ORDER_COLUMN_NAME)

joined_df = self._internal.resolved_copy.spark_frame.drop(NATURAL_ORDER_COLUMN_NAME).join(
labels, on=index_column, how="right"
)
internal = self._internal.with_new_sdf(joined_df)
if fill_value is not None:
frame_index_column = verify_temp_column_name(frame, "__frame_index_column__")
frame = frame.withColumnRenamed(index_column, frame_index_column)

temp_fill_value = verify_temp_column_name(frame, "__fill_value__")
labels = labels.withColumn(temp_fill_value, F.lit(fill_value))

frame_index_scol = scol_for(frame, frame_index_column)
labels_index_scol = scol_for(labels, index_column)

joined_df = frame.join(labels, on=[frame_index_scol == labels_index_scol], how="right")
joined_df = joined_df.select(
labels_index_scol,
*[
F.when(
frame_index_scol.isNull() & labels_index_scol.isNotNull(),
scol_for(joined_df, temp_fill_value),
)
.otherwise(scol_for(joined_df, col))
.alias(col)
for col in self._internal.data_spark_column_names
]
)
else:
joined_df = frame.join(labels, on=index_column, how="right")

internal = self._internal.with_new_sdf(joined_df)
return DataFrame(internal)

def _reindex_columns(self, columns):
def _reindex_columns(self, columns, fill_value):
level = self._internal.column_labels_level
if level > 1:
label_columns = list(columns)
Expand All @@ -7841,12 +7860,13 @@ def _reindex_columns(self, columns):
raise ValueError(
"shape (1,{}) doesn't match the shape (1,{})".format(len(col), level)
)
fill_value = np.nan if fill_value is None else fill_value
scols, labels = [], []
for label in label_columns:
if label in self._internal.column_labels:
scols.append(self._internal.spark_column_for(label))
else:
scols.append(F.lit(np.nan).alias(name_like_string(label)))
scols.append(F.lit(fill_value).alias(name_like_string(label)))
labels.append(label)

return DataFrame(self._internal.with_new_columns(scols, column_labels=labels))
Expand Down
24 changes: 20 additions & 4 deletions databricks/koalas/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2375,8 +2375,8 @@ def test_drop_duplicates(self):

def test_reindex(self):
index = ["A", "B", "C", "D", "E"]
pdf = pd.DataFrame({"numbers": [1.0, 2.0, 3.0, 4.0, 5.0]}, index=index)
kdf = ks.DataFrame({"numbers": [1.0, 2.0, 3.0, 4.0, 5.0]}, index=index)
pdf = pd.DataFrame({"numbers": [1.0, 2.0, 3.0, 4.0, None]}, index=index)
kdf = ks.DataFrame({"numbers": [1.0, 2.0, 3.0, 4.0, None]}, index=index)

self.assert_eq(
pdf.reindex(["A", "B", "C"], columns=["numbers", "2", "3"]).sort_index(),
Expand All @@ -2397,14 +2397,20 @@ def test_reindex(self):
kdf.reindex(index=["A", "B", "2", "3"]).sort_index(),
)

self.assert_eq(
pdf.reindex(index=["A", "E", "2", "3"], fill_value=0).sort_index(),
kdf.reindex(index=["A", "E", "2", "3"], fill_value=0).sort_index(),
)

self.assert_eq(
pdf.reindex(columns=["numbers"]).sort_index(),
kdf.reindex(columns=["numbers"]).sort_index(),
)

# Using float as fill_value to avoid int64/32 clash
self.assert_eq(
pdf.reindex(columns=["numbers", "2", "3"]).sort_index(),
kdf.reindex(columns=["numbers", "2", "3"]).sort_index(),
pdf.reindex(columns=["numbers", "2", "3"], fill_value=0.0).sort_index(),
kdf.reindex(columns=["numbers", "2", "3"], fill_value=0.0).sort_index(),
)

self.assertRaises(TypeError, lambda: kdf.reindex(columns=["numbers", "2", "3"], axis=1))
Expand All @@ -2421,6 +2427,16 @@ def test_reindex(self):
kdf.reindex(columns=[("X", "numbers"), ("Y", "2"), ("Y", "3")]).sort_index(),
)

# Using float as fill_value to avoid int64/32 clash
self.assert_eq(
pdf.reindex(
columns=[("X", "numbers"), ("Y", "2"), ("Y", "3")], fill_value=0.0
).sort_index(),
kdf.reindex(
columns=[("X", "numbers"), ("Y", "2"), ("Y", "3")], fill_value=0.0
).sort_index(),
)

self.assertRaises(TypeError, lambda: kdf.reindex(columns=["X"]))
self.assertRaises(ValueError, lambda: kdf.reindex(columns=[("X",)]))

Expand Down

0 comments on commit 353b55a

Please sign in to comment.