Fix .loc to work properly when using 'slice' #1159

Merged · 14 commits · Jan 8, 2020
59 changes: 47 additions & 12 deletions databricks/koalas/indexing.py
@@ -22,7 +22,7 @@
 from pandas.api.types import is_list_like
 from pyspark import sql as spark
 from pyspark.sql import functions as F
-from pyspark.sql.types import BooleanType
+from pyspark.sql.types import BooleanType, StringType, LongType
 from pyspark.sql.utils import AnalysisException
 
 from databricks.koalas.internal import _InternalFrame, HIDDEN_COLUMNS, NATURAL_ORDER_COLUMN_NAME
@@ -233,7 +233,6 @@ class LocIndexer(_LocIndexerLike):
 
     .. note:: Note that contrary to usual python slices, **both** the
         start and the stop are included, and the step of the slice is not allowed.
-        In addition, with a slice, Koalas works as a filter between the range.
 
     .. note:: With a list or array of labels for row selection,
         Koalas behaves as a filter without reordering by the labels.
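
The two notes above state the slicing contract in words; a short sketch makes it concrete. It assumes the three-row reptile frame used in the doctest and tests elsewhere in this PR, and the expected results are described in comments rather than shown as printed output:

    import databricks.koalas as ks

    df = ks.DataFrame({'max_speed': [1, 4, 7]},
                      index=['cobra', 'viper', 'sidewinder'])

    # Label slice: both the start and the stop label are included,
    # and a step is not allowed.
    df.loc['cobra':'viper', 'max_speed']       # rows 'cobra' and 'viper'
    # df.loc['cobra':'viper':2, 'max_speed']   # error: step is not supported

    # List of labels: acts as a filter, keeping the frame's stored order
    # rather than the order in which the labels were requested.
    df.loc[['viper', 'cobra'], 'max_speed']    # 'cobra' row first, then 'viper'
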
@@ -290,13 +289,9 @@ class LocIndexer(_LocIndexerLike):
     Slice with labels for row and single label for column. As mentioned
     above, note that both the start and stop of the slice are included.
 
-    Also note that the row for 'sidewinder' is included since 'sidewinder'
-    is between 'cobra' and 'viper'.
-
     >>> df.loc['cobra':'viper', 'max_speed']
-    cobra         1
-    viper         4
-    sidewinder    7
+    cobra    1
+    viper    4
     Name: max_speed, dtype: int64
 
     Conditional that returns a boolean Series
@@ -400,16 +395,56 @@ def _select_rows(self, rows_sel):
                 # If slice is None - select everything, so nothing to do
                 return None, None
             elif len(self._internal.index_columns) == 1:
+                sdf = self._kdf_or_kser._internal.sdf
+                index = self._kdf_or_kser.index
+                index_column = index.to_series()
+                index_data_type = index_column.spark_type
                 start = rows_sel.start
                 stop = rows_sel.stop
+                start_order_column = sdf[NATURAL_ORDER_COLUMN_NAME]
+                stop_order_column = sdf[NATURAL_ORDER_COLUMN_NAME]
 
+                # get natural order from '__natural_order__' from start to stop
+                # to keep natural order.
+                start_and_stop = (
+                    sdf.select(index_column._scol, NATURAL_ORDER_COLUMN_NAME)
+                    .where((index_column._scol == start) | (index_column._scol == stop))
+                    .collect())
+
+                start = [row[1] for row in start_and_stop if row[0] == start]
+                start = start[0] if len(start) > 0 else None
+
+                stop = [row[1] for row in start_and_stop if row[0] == stop]
+                stop = stop[0] if len(stop) > 0 else None
+
+                # if index order is not monotonic increasing or decreasing
+                # and specified values don't exist in index, raise KeyError
+                if start is None and rows_sel.start is not None:
+                    if not (index.is_monotonic_decreasing or index.is_monotonic_increasing):
+                        raise KeyError(rows_sel.start)
+                    else:
+                        start = rows_sel.start
+                        start_order_column = index_column._scol
+                if stop is None and rows_sel.stop is not None:
+                    if not (index.is_monotonic_increasing or index.is_monotonic_decreasing):
+                        raise KeyError(rows_sel.stop)
+                    else:
+                        stop = rows_sel.stop
+                        stop_order_column = index_column._scol
+
+                # we don't use StringType since we're using `__natural_order__` for comparing
+                if isinstance(index_data_type, StringType):
+                    index_data_type = LongType()
+
+                # if start and stop are same, just get all start(or stop) values
+                if start == stop:
+                    return index_column._scol == F.lit(rows_sel.start).cast(index_data_type), None
+
-                index_column = self._kdf_or_kser.index.to_series()
-                index_data_type = index_column.spark_type
                 cond = []
                 if start is not None:
-                    cond.append(index_column._scol >= F.lit(start).cast(index_data_type))
+                    cond.append(start_order_column >= F.lit(start).cast(index_data_type))
                 if stop is not None:
-                    cond.append(index_column._scol <= F.lit(stop).cast(index_data_type))
+                    cond.append(stop_order_column <= F.lit(stop).cast(index_data_type))
 
                 if len(cond) > 0:
                     return reduce(lambda x, y: x & y, cond), None
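
In plain terms, the hunk above resolves a label slice in two steps: it looks up the `__natural_order__` values of the start and stop labels and keeps the rows whose natural order falls between them; when a label is missing it falls back to comparing index values directly if the index is monotonic, and raises KeyError otherwise. A minimal pure-Python sketch of that flow, with plain lists standing in for the Spark DataFrame and its `__natural_order__` column (`loc_slice` and its arguments are illustrative names, not koalas APIs):

    def loc_slice(index_labels, values, start_label=None, stop_label=None):
        """Koalas-style .loc label slice, sketched over plain Python lists."""
        order = {label: pos for pos, label in enumerate(index_labels)}
        increasing = all(a <= b for a, b in zip(index_labels, index_labels[1:]))
        decreasing = all(a >= b for a, b in zip(index_labels, index_labels[1:]))

        def resolve(label):
            # Compare by storage position when the label exists, fall back to the
            # label itself when the index is monotonic, raise KeyError otherwise.
            if label is None:
                return None
            if label in order:
                return ('order', order[label])
            if increasing or decreasing:
                return ('label', label)
            raise KeyError(label)

        start, stop = resolve(start_label), resolve(stop_label)

        result = []
        for pos, (label, value) in enumerate(zip(index_labels, values)):
            keys = {'order': pos, 'label': label}
            if start is not None and keys[start[0]] < start[1]:
                continue
            if stop is not None and keys[stop[0]] > stop[1]:
                continue
            result.append(value)
        return result

    # Both endpoints exist: keep everything between their storage positions,
    # endpoints included.
    loc_slice(['cobra', 'viper', 'sidewinder'], [1, 4, 7], 'cobra', 'viper')  # [1, 4]

    # A label missing from a non-monotonic index raises KeyError, as the new
    # tests below assert.
    # loc_slice([10, 30, 20], [1, 4, 7], 0, 30)  # KeyError: 0

The real implementation additionally special-cases start == stop and switches a StringType index to LongType for the comparison, since the `__natural_order__` values being compared are longs; the sketch leaves those details out.
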
5 changes: 1 addition & 4 deletions databricks/koalas/series.py
@@ -3441,10 +3441,7 @@ def truncate(self, before=None, after=None, copy=True):
         if before > after:
             raise ValueError("Truncate: %s must be after %s" % (after, before))
 
-        if indexes_increasing:
-            result = _col(self.to_frame()[before:after])
-        else:
-            result = _col(self.to_frame()[after:before])
+        result = _col(self.to_frame()[before:after])
 
         return result.copy() if copy else result

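
The truncate() change above drops the branch that swapped `before` and `after` for a decreasing index and always slices `before:after`, relying on the fixed label-slice resolution in indexing.py. A small usage sketch of the contract that is preserved (the data is made up; pandas appears only as the reference koalas follows):

    import pandas as pd
    import databricks.koalas as ks

    kser = ks.Series([10, 20, 30, 40, 50], index=[1, 2, 3, 4, 5])

    # Keep labels from `before` through `after`, both bounds included.
    kser.truncate(before=2, after=4)      # index 2, 3, 4 -> values 20, 30, 40

    # Same result as pandas on the same data.
    pd.Series([10, 20, 30, 40, 50], index=[1, 2, 3, 4, 5]).truncate(before=2, after=4)

    # The guard kept by the diff still applies: `before` must not exceed `after`.
    # kser.truncate(before=4, after=2)    # ValueError: Truncate: 2 must be after 4
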
14 changes: 14 additions & 0 deletions databricks/koalas/tests/test_indexing.py
@@ -223,6 +223,20 @@ def test_loc(self):
         self.assert_eq(kdf.loc[1000:], pdf.loc[1000:])
         self.assert_eq(kdf.loc[-2000:-1000], pdf.loc[-2000:-1000])
 
+        # KeyError when index is not monotonic increasing or decreasing
+        # and specified values don't exist in index
+        kdf = ks.DataFrame([[1, 2], [4, 5], [7, 8]],
+                           index=['cobra', 'viper', 'sidewinder'])
+
+        self.assertRaises(KeyError, lambda: kdf.loc['cobra':'koalas'])
+        self.assertRaises(KeyError, lambda: kdf.loc['koalas':'viper'])
+
+        kdf = ks.DataFrame([[1, 2], [4, 5], [7, 8]],
+                           index=[10, 30, 20])
+
+        self.assertRaises(KeyError, lambda: kdf.loc[0:30])
+        self.assertRaises(KeyError, lambda: kdf.loc[10:100])
+
     def test_loc_non_informative_index(self):
         pdf = pd.DataFrame({'x': [1, 2, 3, 4]}, index=[10, 20, 30, 40])
         kdf = ks.from_pandas(pdf)
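
Read together with the indexing.py change, the new tests pin down the user-facing contract: labels that exist in the index are resolved through their stored (natural) order, a missing bound is tolerated only when the index is monotonic, and anything else raises KeyError. A brief sketch of those cases, reusing the frames built in the tests plus one assumed monotonic variant (index [10, 20, 30]); expected results are noted in comments:

    import databricks.koalas as ks

    kdf = ks.DataFrame([[1, 2], [4, 5], [7, 8]],
                       index=['cobra', 'viper', 'sidewinder'])
    kdf.loc['cobra':'viper']       # both labels exist: rows 'cobra' and 'viper'
    # kdf.loc['cobra':'koalas']    # KeyError: 'koalas' (missing label, non-monotonic index)

    kdf = ks.DataFrame([[1, 2], [4, 5], [7, 8]], index=[10, 30, 20])
    # kdf.loc[0:30]                # KeyError: 0
    # kdf.loc[10:100]              # KeyError: 100

    kdf = ks.DataFrame([[1, 2], [4, 5], [7, 8]], index=[10, 20, 30])
    kdf.loc[15:100]                # monotonic index: missing bounds fall back to
                                   # comparing index values, returning rows 20 and 30
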