From 899b34661650d8c1f3b6ccaaf752232a706d59e8 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Thu, 26 Dec 2019 10:22:49 -0800
Subject: [PATCH] Refactor `_LocIndexerLike.__getitem__`. (#1152)

---
 databricks/koalas/indexing.py | 180 +++++++++++++++-------------------
 1 file changed, 77 insertions(+), 103 deletions(-)

diff --git a/databricks/koalas/indexing.py b/databricks/koalas/indexing.py
index b34716f36e..15e4ba2ae7 100644
--- a/databricks/koalas/indexing.py
+++ b/databricks/koalas/indexing.py
@@ -41,49 +41,6 @@ def _make_col(c):
             description="Can only convert a string to a column type.")
 
 
-def _unfold(key, kseries):
-    """ Return row selection and column selection pair.
-
-    If kseries parameter is not None, the key should be row selection and the column selection will
-    be the kseries parameter.
-
-    >>> s = ks.Series([1, 2, 3], name='a')
-    >>> _unfold(slice(1, 2), s)
-    (slice(1, 2, None), 0    1
-    1    2
-    2    3
-    Name: a, dtype: int64)
-
-    >>> _unfold((slice(1, 2), slice(None)), None)
-    (slice(1, 2, None), slice(None, None, None))
-
-    >>> _unfold((slice(1, 2), s), None)
-    (slice(1, 2, None), 0    1
-    1    2
-    2    3
-    Name: a, dtype: int64)
-
-    >>> _unfold((slice(1, 2), 'col'), None)
-    (slice(1, 2, None), 'col')
-    """
-    if kseries is not None:
-        if isinstance(key, tuple):
-            if len(key) > 1:
-                raise SparkPandasIndexingError('Too many indexers')
-            key = key[0]
-        rows_sel = key
-        cols_sel = kseries
-    elif isinstance(key, tuple):
-        if len(key) != 2:
-            raise SparkPandasIndexingError("Only accepts pairs of candidates")
-        rows_sel, cols_sel = key
-    else:
-        rows_sel = key
-        cols_sel = None
-
-    return rows_sel, cols_sel
-
-
 class _IndexerLike(object):
 
     def __init__(self, kdf_or_kser):
@@ -190,51 +147,70 @@ def __getitem__(self, key):
         from databricks.koalas.frame import DataFrame
         from databricks.koalas.series import Series
 
-        rows_sel, cols_sel = _unfold(key, self._kdf_or_kser if self._is_series else None)
+        if self._is_series:
+            if isinstance(key, tuple):
+                if len(key) > 1:
+                    raise SparkPandasIndexingError('Too many indexers')
+                key = key[0]
 
-        cond, limit = self._select_rows(rows_sel)
-        column_index, columns, returns_series = self._select_cols(cols_sel)
+            cond, limit = self._select_rows(key)
+            if cond is None and limit is None:
+                return self._kdf_or_kser
 
-        if cond is None and limit is None and returns_series:
-            if self._is_series:
-                return self._kdf_or_kser._with_new_scol(columns[0])
+            column_index = self._internal.column_index
+            column_scols = self._internal.column_scols
+            returns_series = True
+        else:
+            assert self._is_df
+            if isinstance(key, tuple):
+                if len(key) != 2:
+                    raise SparkPandasIndexingError("Only accepts pairs of candidates")
+                rows_sel, cols_sel = key
             else:
-                return Series(self._internal.copy(scol=columns[0], column_index=[column_index[0]]),
+                rows_sel = key
+                cols_sel = None
+
+            cond, limit = self._select_rows(rows_sel)
+            column_index, column_scols, returns_series = self._select_cols(cols_sel)
+
+            if cond is None and limit is None and returns_series:
+                return Series(self._internal.copy(scol=column_scols[0],
+                                                  column_index=[column_index[0]]),
                               anchor=self._kdf_or_kser)
-        else:
-            try:
-                sdf = self._internal._sdf
-                if cond is not None:
-                    sdf = sdf.where(cond)
-                if limit is not None:
-                    if limit >= 0:
-                        sdf = sdf.limit(limit)
-                    else:
-                        sdf = sdf.limit(sdf.count() + limit)
-
-                sdf = sdf.select(self._internal.index_scols + columns)
-
-                if self._internal.column_index_names is None:
-                    column_index_names = None
+
+        try:
+            sdf = self._internal._sdf
+            if cond is not None:
+                sdf = sdf.where(cond)
+            if limit is not None:
+                if limit >= 0:
+                    sdf = sdf.limit(limit)
                 else:
-                    # Manage column index names
-                    level = column_index_level(column_index)
-                    column_index_names = self._internal.column_index_names[-level:]
-
-                internal = _InternalFrame(sdf=sdf,
-                                          index_map=self._internal.index_map,
-                                          column_index=column_index,
-                                          column_index_names=column_index_names)
-                kdf = DataFrame(internal)
-            except AnalysisException:
-                raise KeyError('[{}] don\'t exist in columns'
-                               .format([col._jc.toString() for col in columns]))
+                    sdf = sdf.limit(sdf.count() + limit)
 
-            if returns_series:
-                return Series(kdf._internal.copy(scol=kdf._internal.column_scols[0]),
-                              anchor=kdf)
+            sdf = sdf.select(self._internal.index_scols + column_scols)
+
+            if self._internal.column_index_names is None:
+                column_index_names = None
             else:
-                return kdf
+                # Manage column index names
+                level = column_index_level(column_index)
+                column_index_names = self._internal.column_index_names[-level:]
+
+            internal = _InternalFrame(sdf=sdf,
+                                      index_map=self._internal.index_map,
+                                      column_index=column_index,
+                                      column_index_names=column_index_names)
+            kdf = DataFrame(internal)
+        except AnalysisException:
+            raise KeyError('[{}] don\'t exist in columns'
+                           .format([col._jc.toString() for col in column_scols]))
+
+        if returns_series:
+            return Series(kdf._internal.copy(scol=kdf._internal.column_scols[0]),
+                          anchor=kdf)
+        else:
+            return kdf
 
 
 class LocIndexer(_LocIndexerLike):
@@ -488,8 +464,7 @@ def _get_from_multiindex_column(self, key, indexes=None):
         if all(len(idx) > 0 and idx[0] == '' for _, idx in indexes):
             # If the head is '', drill down recursively.
             indexes = [(col, tuple([str(key), *idx[1:]])) for i, (col, idx) in enumerate(indexes)]
-            column_index, columns, returns_series = \
-                self._get_from_multiindex_column((str(key),), indexes)
+            return self._get_from_multiindex_column((str(key),), indexes)
         else:
             returns_series = all(len(idx) == 0 for _, idx in indexes)
             if returns_series:
@@ -497,12 +472,12 @@
                 assert len(idxes) == 1
                 index = list(idxes)[0]
                 column_index = [index]
-                columns = [self._internal.scol_for(index)]
+                column_scols = [self._internal.scol_for(index)]
             else:
                 column_index = [idx for _, idx in indexes]
-                columns = [self._internal.scol_for(idx) for idx, _ in indexes]
+                column_scols = [self._internal.scol_for(idx) for idx, _ in indexes]
 
-        return column_index, columns, returns_series
+        return column_index, column_scols, returns_series
 
     def _select_cols(self, cols_sel):
         from databricks.koalas.series import Series
@@ -520,21 +495,20 @@
 
         if cols_sel is None:
             column_index = self._internal.column_index
-            columns = self._internal.column_scols
+            column_scols = self._internal.column_scols
         elif isinstance(cols_sel, (str, tuple)):
             if isinstance(cols_sel, str):
                 cols_sel = (cols_sel,)
-            column_index, columns, returns_series = \
-                self._get_from_multiindex_column(cols_sel)
+            return self._get_from_multiindex_column(cols_sel)
         elif isinstance(cols_sel, spark.Column):
-            columns = [cols_sel]
-            column_index = None
+            column_index = [(self._internal.sdf.select(cols_sel).columns[0],)]
+            column_scols = [cols_sel]
        elif all(isinstance(key, Series) for key in cols_sel):
-            columns = [_make_col(key) for key in cols_sel]
             column_index = [key._internal.column_index[0] for key in cols_sel]
+            column_scols = [_make_col(key) for key in cols_sel]
         elif all(isinstance(key, spark.Column) for key in cols_sel):
-            columns = cols_sel
-            column_index = None
+            column_index = [(self._internal.sdf.select(col).columns[0],) for col in cols_sel]
+            column_scols = cols_sel
         elif (any(isinstance(key, str) for key in cols_sel)
                 and any(isinstance(key, tuple) for key in cols_sel)):
             raise TypeError('Expected tuple, got str')
@@ -546,19 +520,19 @@
             column_to_index = list(zip(self._internal.data_columns,
                                        self._internal.column_index))
 
-            columns = []
             column_index = []
+            column_scols = []
             for key in cols_sel:
                 found = False
                 for column, idx in column_to_index:
                     if idx == key or idx[0] == key:
-                        columns.append(_make_col(column))
                         column_index.append(idx)
+                        column_scols.append(_make_col(column))
                         found = True
                 if not found:
                     raise KeyError("['{}'] not in index".format(name_like_string(key)))
 
-        return column_index, columns, returns_series
+        return column_index, column_scols, returns_series
 
     def __setitem__(self, key, value):
         from databricks.koalas.frame import DataFrame
@@ -768,19 +742,19 @@ def _select_cols(self, cols_sel):
 
         # make cols_sel a 1-tuple of string if a single string
         if isinstance(cols_sel, Series) and cols_sel._equals(self._kdf_or_kser):
-            columns = cols_sel._internal.column_scols
             column_index = cols_sel._internal.column_index
+            column_scols = cols_sel._internal.column_scols
         elif isinstance(cols_sel, int):
-            columns = [self._internal.column_scols[cols_sel]]
             column_index = [self._internal.column_index[cols_sel]]
+            column_scols = [self._internal.column_scols[cols_sel]]
         elif cols_sel is None or cols_sel == slice(None):
-            columns = self._internal.column_scols
             column_index = self._internal.column_index
+            column_scols = self._internal.column_scols
         elif isinstance(cols_sel, slice):
             if all(s is None or isinstance(s, int)
                    for s in (cols_sel.start, cols_sel.stop, cols_sel.step)):
-                columns = self._internal.column_scols[cols_sel]
                 column_index = self._internal.column_index[cols_sel]
+                column_scols = self._internal.column_scols[cols_sel]
             else:
                 not_none = cols_sel.start if cols_sel.start is not None \
                     else cols_sel.stop if cols_sel.stop is not None else cols_sel.step
@@ -790,12 +764,12 @@
             if all(isinstance(s, bool) for s in cols_sel):
                 cols_sel = [i for i, s in enumerate(cols_sel) if s]
             if all(isinstance(s, int) for s in cols_sel):
-                columns = [self._internal.column_scols[s] for s in cols_sel]
                 column_index = [self._internal.column_index[s] for s in cols_sel]
+                column_scols = [self._internal.column_scols[s] for s in cols_sel]
             else:
                 raise TypeError('cannot perform reduce with flexible type')
         else:
             raise ValueError("Location based indexing can only have [integer, integer slice, "
                              "listlike of integers, boolean array] types, got {}".format(cols_sel))
 
-        return column_index, columns, returns_series
+        return column_index, column_scols, returns_series
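
Note (illustrative, not part of the patch): the dispatch that this refactoring centralizes in
`_LocIndexerLike.__getitem__` uses the `returns_series` flag from `_select_cols` to decide
whether an indexing expression yields a Series or a DataFrame. A minimal doctest-style sketch
of the behavior being preserved, assuming a Koalas 0.x environment with an active Spark
session (`kdf` is a hypothetical example frame, not from the patch):

    >>> import databricks.koalas as ks
    >>> kdf = ks.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
    >>> type(kdf.loc[:, 'a']).__name__         # single label -> returns_series=True
    'Series'
    >>> type(kdf.loc[:, ['a', 'b']]).__name__  # list of labels -> DataFrame
    'DataFrame'
    >>> type(kdf.iloc[:, 0]).__name__          # integer position -> Series
    'Series'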