Skip to content

Commit 5a7c2fc

Browse files
authored
Explicitly disallow empty list as index_spark_column_names and index_names. (#1895)
When an empty list is passed to `InternalFrame.index_spark_column_names` or `.index_names`, it is currently accepted without any error and later causes unexpected errors elsewhere. We should handle an empty list the same as `None`.
1 parent 8c4dd59 commit 5a7c2fc

File tree

4 files changed

+22
-35
lines changed

4 files changed

+22
-35
lines changed

databricks/koalas/frame.py

+4-15
Original file line numberDiff line numberDiff line change
@@ -3501,21 +3501,10 @@ def rename(index):
35013501
scol_for(sdf, column).alias(name_like_string(name)) for column, name in new_index_map
35023502
]
35033503

3504-
if len(index_map) > 0: # type: ignore
3505-
index_scols = [scol_for(sdf, column) for column in index_map]
3506-
sdf = sdf.select(
3507-
index_scols
3508-
+ new_data_scols
3509-
+ self._internal.data_spark_columns
3510-
+ list(HIDDEN_COLUMNS)
3511-
)
3512-
else:
3513-
sdf = sdf.select(
3514-
new_data_scols + self._internal.data_spark_columns + list(HIDDEN_COLUMNS)
3515-
)
3516-
3517-
sdf = InternalFrame.attach_default_index(sdf)
3518-
index_map = OrderedDict({SPARK_DEFAULT_INDEX_NAME: None})
3504+
index_scols = [scol_for(sdf, column) for column in index_map]
3505+
sdf = sdf.select(
3506+
index_scols + new_data_scols + self._internal.data_spark_columns + list(HIDDEN_COLUMNS)
3507+
)
35193508

35203509
if self._internal.column_labels_level > 1:
35213510
column_depth = len(self._internal.column_labels[0])

databricks/koalas/groupby.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
from databricks.koalas.config import get_option
6262
from databricks.koalas.utils import (
6363
align_diff_frames,
64-
column_labels_level,
6564
is_name_like_tuple,
6665
is_name_like_value,
6766
name_like_string,
@@ -1318,29 +1317,28 @@ def _make_pandas_df_builder_func(kdf, func, return_schema, retain_index):
13181317
index_names = kdf._internal.index_names
13191318
data_columns = kdf._internal.data_spark_column_names
13201319
column_labels = kdf._internal.column_labels
1320+
column_labels_level = kdf._internal.column_labels_level
13211321

13221322
def rename_output(pdf):
13231323
# TODO: This logic below was borrowed from `DataFrame.to_pandas_frame` to set the index
13241324
# within each pdf properly. we might have to deduplicate it.
13251325
import pandas as pd
13261326

1327-
if len(index_columns) > 0:
1328-
append = False
1329-
for index_field in index_columns:
1330-
drop = index_field not in data_columns
1331-
pdf = pdf.set_index(index_field, drop=drop, append=append)
1332-
append = True
1333-
pdf = pdf[data_columns]
1327+
append = False
1328+
for index_field in index_columns:
1329+
drop = index_field not in data_columns
1330+
pdf = pdf.set_index(index_field, drop=drop, append=append)
1331+
append = True
1332+
pdf = pdf[data_columns]
13341333

1335-
if column_labels_level(column_labels) > 1:
1334+
if column_labels_level > 1:
13361335
pdf.columns = pd.MultiIndex.from_tuples(column_labels)
13371336
else:
13381337
pdf.columns = [None if label is None else label[0] for label in column_labels]
13391338

1340-
if len(index_names) > 0:
1341-
pdf.index.names = [
1342-
name if name is None or len(name) > 1 else name[0] for name in index_names
1343-
]
1339+
pdf.index.names = [
1340+
name if name is None or len(name) > 1 else name[0] for name in index_names
1341+
]
13441342

13451343
pdf = func(pdf)
13461344

databricks/koalas/internal.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ def __init__(
445445
assert isinstance(spark_frame, spark.DataFrame)
446446
assert not spark_frame.isStreaming, "Koalas does not support Structured Streaming."
447447

448-
if index_spark_column_names is None:
448+
if not index_spark_column_names:
449449
assert not any(SPARK_INDEX_NAME_PATTERN.match(name) for name in spark_frame.columns), (
450450
"Index columns should not appear in columns of the Spark DataFrame. Avoid "
451451
"index column names [%s]." % SPARK_INDEX_NAME_PATTERN
@@ -470,7 +470,7 @@ def __init__(
470470
NATURAL_ORDER_COLUMN_NAME, F.monotonically_increasing_id()
471471
)
472472

473-
if index_names is None:
473+
if not index_names:
474474
index_names = [None] * len(index_spark_column_names)
475475

476476
assert len(index_spark_column_names) == len(index_names), (
@@ -857,11 +857,10 @@ def to_pandas_frame(self) -> pd.DataFrame:
857857
name=names[0],
858858
)
859859

860-
index_names = self.index_names
861-
if len(index_names) > 0:
862-
pdf.index.names = [
863-
name if name is None or len(name) > 1 else name[0] for name in index_names
864-
]
860+
pdf.index.names = [
861+
name if name is None or len(name) > 1 else name[0] for name in self.index_names
862+
]
863+
865864
return pdf
866865

867866
@lazy_property

databricks/koalas/tests/test_dataframe.py

+1
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def test_reset_index(self):
207207
kdf = ks.from_pandas(pdf)
208208

209209
self.assert_eq(kdf.reset_index(), pdf.reset_index())
210+
self.assert_eq(kdf.reset_index().index, pdf.reset_index().index)
210211
self.assert_eq(kdf.reset_index(drop=True), pdf.reset_index(drop=True))
211212

212213
pdf.index.name = "a"

0 commit comments

Comments
 (0)