Skip to content

Commit

Permalink
Refactor LocIndexerLike__getitem__. (#1152)
Browse files Browse the repository at this point in the history
  • Loading branch information
ueshin authored Dec 26, 2019
1 parent 0541a5d commit 899b346
Showing 1 changed file with 77 additions and 103 deletions.
180 changes: 77 additions & 103 deletions databricks/koalas/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,49 +41,6 @@ def _make_col(c):
description="Can only convert a string to a column type.")


def _unfold(key, kseries):
""" Return row selection and column selection pair.
If kseries parameter is not None, the key should be row selection and the column selection will
be the kseries parameter.
>>> s = ks.Series([1, 2, 3], name='a')
>>> _unfold(slice(1, 2), s)
(slice(1, 2, None), 0 1
1 2
2 3
Name: a, dtype: int64)
>>> _unfold((slice(1, 2), slice(None)), None)
(slice(1, 2, None), slice(None, None, None))
>>> _unfold((slice(1, 2), s), None)
(slice(1, 2, None), 0 1
1 2
2 3
Name: a, dtype: int64)
>>> _unfold((slice(1, 2), 'col'), None)
(slice(1, 2, None), 'col')
"""
if kseries is not None:
if isinstance(key, tuple):
if len(key) > 1:
raise SparkPandasIndexingError('Too many indexers')
key = key[0]
rows_sel = key
cols_sel = kseries
elif isinstance(key, tuple):
if len(key) != 2:
raise SparkPandasIndexingError("Only accepts pairs of candidates")
rows_sel, cols_sel = key
else:
rows_sel = key
cols_sel = None

return rows_sel, cols_sel


class _IndexerLike(object):

def __init__(self, kdf_or_kser):
Expand Down Expand Up @@ -190,51 +147,70 @@ def __getitem__(self, key):
from databricks.koalas.frame import DataFrame
from databricks.koalas.series import Series

rows_sel, cols_sel = _unfold(key, self._kdf_or_kser if self._is_series else None)
if self._is_series:
if isinstance(key, tuple):
if len(key) > 1:
raise SparkPandasIndexingError('Too many indexers')
key = key[0]

cond, limit = self._select_rows(rows_sel)
column_index, columns, returns_series = self._select_cols(cols_sel)
cond, limit = self._select_rows(key)
if cond is None and limit is None:
return self._kdf_or_kser

if cond is None and limit is None and returns_series:
if self._is_series:
return self._kdf_or_kser._with_new_scol(columns[0])
column_index = self._internal.column_index
column_scols = self._internal.column_scols
returns_series = True
else:
assert self._is_df
if isinstance(key, tuple):
if len(key) != 2:
raise SparkPandasIndexingError("Only accepts pairs of candidates")
rows_sel, cols_sel = key
else:
return Series(self._internal.copy(scol=columns[0], column_index=[column_index[0]]),
rows_sel = key
cols_sel = None

cond, limit = self._select_rows(rows_sel)
column_index, column_scols, returns_series = self._select_cols(cols_sel)

if cond is None and limit is None and returns_series:
return Series(self._internal.copy(scol=column_scols[0],
column_index=[column_index[0]]),
anchor=self._kdf_or_kser)
else:
try:
sdf = self._internal._sdf
if cond is not None:
sdf = sdf.where(cond)
if limit is not None:
if limit >= 0:
sdf = sdf.limit(limit)
else:
sdf = sdf.limit(sdf.count() + limit)

sdf = sdf.select(self._internal.index_scols + columns)

if self._internal.column_index_names is None:
column_index_names = None

try:
sdf = self._internal._sdf
if cond is not None:
sdf = sdf.where(cond)
if limit is not None:
if limit >= 0:
sdf = sdf.limit(limit)
else:
# Manage column index names
level = column_index_level(column_index)
column_index_names = self._internal.column_index_names[-level:]

internal = _InternalFrame(sdf=sdf,
index_map=self._internal.index_map,
column_index=column_index,
column_index_names=column_index_names)
kdf = DataFrame(internal)
except AnalysisException:
raise KeyError('[{}] don\'t exist in columns'
.format([col._jc.toString() for col in columns]))
sdf = sdf.limit(sdf.count() + limit)

if returns_series:
return Series(kdf._internal.copy(scol=kdf._internal.column_scols[0]),
anchor=kdf)
sdf = sdf.select(self._internal.index_scols + column_scols)

if self._internal.column_index_names is None:
column_index_names = None
else:
return kdf
# Manage column index names
level = column_index_level(column_index)
column_index_names = self._internal.column_index_names[-level:]

internal = _InternalFrame(sdf=sdf,
index_map=self._internal.index_map,
column_index=column_index,
column_index_names=column_index_names)
kdf = DataFrame(internal)
except AnalysisException:
raise KeyError('[{}] don\'t exist in columns'
.format([col._jc.toString() for col in column_scols]))

if returns_series:
return Series(kdf._internal.copy(scol=kdf._internal.column_scols[0]),
anchor=kdf)
else:
return kdf


class LocIndexer(_LocIndexerLike):
Expand Down Expand Up @@ -488,21 +464,20 @@ def _get_from_multiindex_column(self, key, indexes=None):
if all(len(idx) > 0 and idx[0] == '' for _, idx in indexes):
# If the head is '', drill down recursively.
indexes = [(col, tuple([str(key), *idx[1:]])) for i, (col, idx) in enumerate(indexes)]
column_index, columns, returns_series = \
self._get_from_multiindex_column((str(key),), indexes)
return self._get_from_multiindex_column((str(key),), indexes)
else:
returns_series = all(len(idx) == 0 for _, idx in indexes)
if returns_series:
idxes = set(idx for idx, _ in indexes)
assert len(idxes) == 1
index = list(idxes)[0]
column_index = [index]
columns = [self._internal.scol_for(index)]
column_scols = [self._internal.scol_for(index)]
else:
column_index = [idx for _, idx in indexes]
columns = [self._internal.scol_for(idx) for idx, _ in indexes]
column_scols = [self._internal.scol_for(idx) for idx, _ in indexes]

return column_index, columns, returns_series
return column_index, column_scols, returns_series

def _select_cols(self, cols_sel):
from databricks.koalas.series import Series
Expand All @@ -520,21 +495,20 @@ def _select_cols(self, cols_sel):

if cols_sel is None:
column_index = self._internal.column_index
columns = self._internal.column_scols
column_scols = self._internal.column_scols
elif isinstance(cols_sel, (str, tuple)):
if isinstance(cols_sel, str):
cols_sel = (cols_sel,)
column_index, columns, returns_series = \
self._get_from_multiindex_column(cols_sel)
return self._get_from_multiindex_column(cols_sel)
elif isinstance(cols_sel, spark.Column):
columns = [cols_sel]
column_index = None
column_index = [(self._internal.sdf.select(cols_sel).columns[0],)]
column_scols = [cols_sel]
elif all(isinstance(key, Series) for key in cols_sel):
columns = [_make_col(key) for key in cols_sel]
column_index = [key._internal.column_index[0] for key in cols_sel]
column_scols = [_make_col(key) for key in cols_sel]
elif all(isinstance(key, spark.Column) for key in cols_sel):
columns = cols_sel
column_index = None
column_index = [(self._internal.sdf.select(col).columns[0],) for col in cols_sel]
column_scols = cols_sel
elif (any(isinstance(key, str) for key in cols_sel)
and any(isinstance(key, tuple) for key in cols_sel)):
raise TypeError('Expected tuple, got str')
Expand All @@ -546,19 +520,19 @@ def _select_cols(self, cols_sel):

column_to_index = list(zip(self._internal.data_columns,
self._internal.column_index))
columns = []
column_index = []
column_scols = []
for key in cols_sel:
found = False
for column, idx in column_to_index:
if idx == key or idx[0] == key:
columns.append(_make_col(column))
column_index.append(idx)
column_scols.append(_make_col(column))
found = True
if not found:
raise KeyError("['{}'] not in index".format(name_like_string(key)))

return column_index, columns, returns_series
return column_index, column_scols, returns_series

def __setitem__(self, key, value):
from databricks.koalas.frame import DataFrame
Expand Down Expand Up @@ -768,19 +742,19 @@ def _select_cols(self, cols_sel):

# make cols_sel a 1-tuple of string if a single string
if isinstance(cols_sel, Series) and cols_sel._equals(self._kdf_or_kser):
columns = cols_sel._internal.column_scols
column_index = cols_sel._internal.column_index
column_scols = cols_sel._internal.column_scols
elif isinstance(cols_sel, int):
columns = [self._internal.column_scols[cols_sel]]
column_index = [self._internal.column_index[cols_sel]]
column_scols = [self._internal.column_scols[cols_sel]]
elif cols_sel is None or cols_sel == slice(None):
columns = self._internal.column_scols
column_index = self._internal.column_index
column_scols = self._internal.column_scols
elif isinstance(cols_sel, slice):
if all(s is None or isinstance(s, int)
for s in (cols_sel.start, cols_sel.stop, cols_sel.step)):
columns = self._internal.column_scols[cols_sel]
column_index = self._internal.column_index[cols_sel]
column_scols = self._internal.column_scols[cols_sel]
else:
not_none = cols_sel.start if cols_sel.start is not None \
else cols_sel.stop if cols_sel.stop is not None else cols_sel.step
Expand All @@ -790,12 +764,12 @@ def _select_cols(self, cols_sel):
if all(isinstance(s, bool) for s in cols_sel):
cols_sel = [i for i, s in enumerate(cols_sel) if s]
if all(isinstance(s, int) for s in cols_sel):
columns = [self._internal.column_scols[s] for s in cols_sel]
column_index = [self._internal.column_index[s] for s in cols_sel]
column_scols = [self._internal.column_scols[s] for s in cols_sel]
else:
raise TypeError('cannot perform reduce with flexible type')
else:
raise ValueError("Location based indexing can only have [integer, integer slice, "
"listlike of integers, boolean array] types, got {}".format(cols_sel))

return column_index, columns, returns_series
return column_index, column_scols, returns_series

0 comments on commit 899b346

Please sign in to comment.