Skip to content

Commit 0541a5d

Browse files
authored
Introduce _LocIndexerLike and consolidate some logic. (#1149)
1 parent 7f59f81 commit 0541a5d

File tree

1 file changed

+68
-79
lines changed

1 file changed

+68
-79
lines changed

databricks/koalas/indexing.py

+68-79
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,60 @@ def __getitem__(self, key):
184184
or len(values) > 1) else values[0]
185185

186186

187-
class LocIndexer(_IndexerLike):
187+
class _LocIndexerLike(_IndexerLike):
    """
    Shared ``__getitem__`` implementation for indexers that select rows and columns.

    Subclasses are expected to provide the two hooks called by ``__getitem__``:

    - ``_select_rows(rows_sel)`` returning ``(cond, limit)``, where ``cond`` is either
      ``None`` or a condition passed to ``sdf.where``, and ``limit`` is either ``None`` or
      an integer row limit (a negative value is interpreted relative to the row count).
    - ``_select_cols(cols_sel)`` returning ``(column_index, columns, returns_series)``.
    """

    def __getitem__(self, key):
        """
        Select rows and/or columns described by ``key``.

        ``key`` is split into a row selector and a column selector by ``_unfold``; each is
        resolved by the subclass hooks and the results are applied to the underlying
        Spark DataFrame.

        Returns a ``Series`` when the column selection reports ``returns_series``,
        otherwise a ``DataFrame``.

        Raises ``KeyError`` if Spark raises ``AnalysisException`` while building the
        selected frame.
        """
        from databricks.koalas.frame import DataFrame
        from databricks.koalas.series import Series

        rows_sel, cols_sel = _unfold(key, self._kdf_or_kser if self._is_series else None)

        cond, limit = self._select_rows(rows_sel)
        column_index, columns, returns_series = self._select_cols(cols_sel)

        if cond is None and limit is None and returns_series:
            # No row filtering requested: the single selected column can be returned
            # anchored to the original object, without building a new Spark DataFrame.
            if self._is_series:
                return self._kdf_or_kser._with_new_scol(columns[0])
            else:
                return Series(self._internal.copy(scol=columns[0], column_index=[column_index[0]]),
                              anchor=self._kdf_or_kser)
        else:
            try:
                sdf = self._internal._sdf
                if cond is not None:
                    sdf = sdf.where(cond)
                if limit is not None:
                    if limit >= 0:
                        sdf = sdf.limit(limit)
                    else:
                        # A negative limit drops that many trailing rows. Clamp at zero:
                        # when more rows are dropped than exist, the result must be empty.
                        # Without the clamp Spark rejects the negative limit with an
                        # AnalysisException, which the handler below would misreport as a
                        # missing-column KeyError.
                        sdf = sdf.limit(max(0, sdf.count() + limit))

                sdf = sdf.select(self._internal.index_scols + columns)

                if self._internal.column_index_names is None:
                    column_index_names = None
                else:
                    # Manage column index names
                    level = column_index_level(column_index)
                    column_index_names = self._internal.column_index_names[-level:]

                internal = _InternalFrame(sdf=sdf,
                                          index_map=self._internal.index_map,
                                          column_index=column_index,
                                          column_index_names=column_index_names)
                kdf = DataFrame(internal)
            except AnalysisException:
                raise KeyError('[{}] don\'t exist in columns'
                               .format([col._jc.toString() for col in columns]))

            if returns_series:
                # Rows were filtered, so the Series is anchored to the newly built frame.
                return Series(kdf._internal.copy(scol=kdf._internal.column_scols[0]),
                              anchor=kdf)
            else:
                return kdf
238+
239+
240+
class LocIndexer(_LocIndexerLike):
188241
"""
189242
Access a group of rows and columns by label(s) or a boolean Series.
190243
@@ -372,14 +425,14 @@ def _select_rows(self, rows_sel):
372425

373426
if isinstance(rows_sel, Series):
374427
assert isinstance(rows_sel.spark_type, BooleanType), rows_sel.spark_type
375-
return rows_sel._scol
428+
return rows_sel._scol, None
376429
elif isinstance(rows_sel, slice):
377430
assert len(self._internal.index_columns) > 0
378431
if rows_sel.step is not None:
379432
LocIndexer._raiseNotImplemented("Cannot use step with Spark.")
380433
if rows_sel == slice(None):
381434
# If slice is None - select everything, so nothing to do
382-
return None
435+
return None, None
383436
elif len(self._internal.index_columns) == 1:
384437
start = rows_sel.start
385438
stop = rows_sel.stop
@@ -393,7 +446,7 @@ def _select_rows(self, rows_sel):
393446
cond.append(index_column._scol <= F.lit(stop).cast(index_data_type))
394447

395448
if len(cond) > 0:
396-
return reduce(lambda x, y: x & y, cond)
449+
return reduce(lambda x, y: x & y, cond), None
397450
else:
398451
LocIndexer._raiseNotImplemented("Cannot use slice for MultiIndex with Spark.")
399452
elif isinstance(rows_sel, str):
@@ -406,15 +459,15 @@ def _select_rows(self, rows_sel):
406459
LocIndexer._raiseNotImplemented(
407460
"Cannot use a scalar value for row selection with Spark.")
408461
if len(rows_sel) == 0:
409-
return F.lit(False)
462+
return F.lit(False), None
410463
elif len(self._internal.index_columns) == 1:
411464
index_column = self._kdf_or_kser.index.to_series()
412465
index_data_type = index_column.spark_type
413466
if len(rows_sel) == 1:
414-
return index_column._scol == F.lit(rows_sel[0]).cast(index_data_type)
467+
return index_column._scol == F.lit(rows_sel[0]).cast(index_data_type), None
415468
else:
416469
return index_column._scol.isin(
417-
[F.lit(r).cast(index_data_type) for r in rows_sel])
470+
[F.lit(r).cast(index_data_type) for r in rows_sel]), None
418471
else:
419472
LocIndexer._raiseNotImplemented("Cannot select with MultiIndex with Spark.")
420473

@@ -451,14 +504,9 @@ def _get_from_multiindex_column(self, key, indexes=None):
451504

452505
return column_index, columns, returns_series
453506

454-
def __getitem__(self, key):
455-
from databricks.koalas.frame import DataFrame
507+
def _select_cols(self, cols_sel):
456508
from databricks.koalas.series import Series
457509

458-
rows_sel, cols_sel = _unfold(key, self._kdf_or_kser if self._is_series else None)
459-
460-
cond = self._select_rows(rows_sel)
461-
462510
# make cols_sel a 1-tuple of string if a single string
463511
if isinstance(cols_sel, Series):
464512
cols_sel = _make_col(cols_sel)
@@ -469,6 +517,7 @@ def __getitem__(self, key):
469517
cols_sel = None
470518

471519
returns_series = cols_sel is not None and isinstance(cols_sel, spark.Column)
520+
472521
if cols_sel is None:
473522
column_index = self._internal.column_index
474523
columns = self._internal.column_scols
@@ -507,43 +556,9 @@ def __getitem__(self, key):
507556
column_index.append(idx)
508557
found = True
509558
if not found:
510-
raise KeyError("['{}'] not in index".format(key))
511-
512-
if cond is None and returns_series:
513-
if self._is_series:
514-
return self._kdf_or_kser._with_new_scol(columns[0])
515-
else:
516-
return Series(self._internal.copy(scol=columns[0], column_index=[column_index[0]]),
517-
anchor=self._kdf_or_kser)
518-
else:
519-
try:
520-
sdf = self._internal._sdf
521-
if cond is not None:
522-
sdf = sdf.where(cond)
523-
524-
sdf = sdf.select(self._internal.index_scols + columns)
525-
526-
if self._internal.column_index_names is None:
527-
column_index_names = None
528-
else:
529-
# Manage column index names
530-
level = column_index_level(column_index)
531-
column_index_names = self._internal.column_index_names[-level:]
532-
533-
internal = _InternalFrame(sdf=sdf,
534-
index_map=self._internal.index_map,
535-
column_index=column_index,
536-
column_index_names=column_index_names)
537-
kdf = DataFrame(internal)
538-
except AnalysisException:
539-
raise KeyError('[{}] don\'t exist in columns'
540-
.format([col._jc.toString() for col in columns]))
559+
raise KeyError("['{}'] not in index".format(name_like_string(key)))
541560

542-
if returns_series:
543-
return Series(kdf._internal.copy(scol=kdf._internal.column_scols[0]),
544-
anchor=kdf)
545-
else:
546-
return kdf
561+
return column_index, columns, returns_series
547562

548563
def __setitem__(self, key, value):
549564
from databricks.koalas.frame import DataFrame
@@ -604,7 +619,7 @@ def __setitem__(self, key, value):
604619
self._kdf_or_kser[col_sel] = value
605620

606621

607-
class ILocIndexer(_IndexerLike):
622+
class ILocIndexer(_LocIndexerLike):
608623
"""
609624
Purely integer-location based indexing for selection by position.
610625
@@ -746,21 +761,10 @@ def _select_rows(self, rows_sel):
746761
ILocIndexer._raiseNotImplemented(".iloc requires numeric slice or conditional "
747762
"boolean Index, got {}".format(rows_sel))
748763

749-
def __getitem__(self, key):
750-
from databricks.koalas.frame import DataFrame
764+
def _select_cols(self, cols_sel):
751765
from databricks.koalas.series import Series
752766

753-
rows_sel, cols_sel = _unfold(key, self._kdf_or_kser if self._is_series else None)
754-
755-
sdf = self._internal.sdf
756-
cond, limit = self._select_rows(rows_sel)
757-
if cond is not None:
758-
sdf = sdf.where(cond)
759-
if limit is not None:
760-
if limit >= 0:
761-
sdf = sdf.limit(limit)
762-
else:
763-
sdf = sdf.limit(sdf.count() + limit)
767+
returns_series = cols_sel is not None and isinstance(cols_sel, (Series, int))
764768

765769
# make cols_sel a 1-tuple of string if a single string
766770
if isinstance(cols_sel, Series) and cols_sel._equals(self._kdf_or_kser):
@@ -794,19 +798,4 @@ def __getitem__(self, key):
794798
raise ValueError("Location based indexing can only have [integer, integer slice, "
795799
"listlike of integers, boolean array] types, got {}".format(cols_sel))
796800

797-
try:
798-
sdf = sdf.select(self._internal.index_scols + columns)
799-
internal = _InternalFrame(sdf=sdf,
800-
index_map=self._internal.index_map,
801-
column_index=column_index,
802-
column_index_names=self._internal.column_index_names)
803-
kdf = DataFrame(internal)
804-
except AnalysisException:
805-
raise KeyError('[{}] don\'t exist in columns'
806-
.format([col._jc.toString() for col in columns]))
807-
808-
if cols_sel is not None and isinstance(cols_sel, (Series, int)):
809-
from databricks.koalas.series import _col
810-
return _col(kdf)
811-
else:
812-
return kdf
801+
return column_index, columns, returns_series

0 commit comments

Comments
 (0)