@@ -184,7 +184,60 @@ def __getitem__(self, key):
184
184
or len (values ) > 1 ) else values [0 ]
185
185
186
186
187
- class LocIndexer (_IndexerLike ):
187
class _LocIndexerLike(_IndexerLike):
    """
    Shared ``__getitem__`` implementation for the row/column indexers.

    Subclasses provide two hooks that this class combines:

    - ``_select_rows(rows_sel)`` returns ``(cond, limit)``. ``cond`` is passed to
      ``where`` on the Spark DataFrame and ``limit`` to ``limit``; either may be
      ``None`` to skip that step.
    - ``_select_cols(cols_sel)`` returns ``(column_index, columns, returns_series)``.
      ``columns`` are selected after the index columns, ``column_index`` labels
      them, and ``returns_series`` tells whether the result is a Series.
    """

    def __getitem__(self, key):
        """
        Select rows and columns described by ``key``.

        Parameters
        ----------
        key
            The selection, split into a row part and a column part by ``_unfold``.

        Returns
        -------
        Series or DataFrame
            A Series when ``_select_cols`` reports ``returns_series``, otherwise
            a DataFrame.

        Raises
        ------
        KeyError
            If Spark raises ``AnalysisException`` while building the selection.
        """
        from databricks.koalas.frame import DataFrame
        from databricks.koalas.series import Series

        rows_sel, cols_sel = _unfold(key, self._kdf_or_kser if self._is_series else None)

        cond, limit = self._select_rows(rows_sel)
        column_index, columns, returns_series = self._select_cols(cols_sel)

        # Fast path: no row filtering at all and a single column requested, so the
        # result can stay anchored to the original object without a new Spark plan.
        if cond is None and limit is None and returns_series:
            if self._is_series:
                return self._kdf_or_kser._with_new_scol(columns[0])
            return Series(self._internal.copy(scol=columns[0], column_index=[column_index[0]]),
                          anchor=self._kdf_or_kser)

        try:
            sdf = self._internal._sdf
            if cond is not None:
                sdf = sdf.where(cond)
            if limit is not None:
                if limit < 0:
                    # A negative limit means "all but the last -limit rows", which
                    # needs the row count (this runs a Spark job). Clamp at 0: when
                    # fewer rows exist than are being dropped the result is empty,
                    # whereas a negative value is rejected by Spark and would be
                    # misreported below as a missing-column KeyError.
                    limit = max(sdf.count() + limit, 0)
                sdf = sdf.limit(limit)

            sdf = sdf.select(self._internal.index_scols + columns)

            if self._internal.column_index_names is None:
                column_index_names = None
            else:
                # Keep only the names of the column index levels that remain.
                level = column_index_level(column_index)
                column_index_names = self._internal.column_index_names[-level:]

            internal = _InternalFrame(sdf=sdf,
                                      index_map=self._internal.index_map,
                                      column_index=column_index,
                                      column_index_names=column_index_names)
            kdf = DataFrame(internal)
        except AnalysisException:
            raise KeyError('[{}] don\'t exist in columns'
                           .format([col._jc.toString() for col in columns]))

        if returns_series:
            return Series(kdf._internal.copy(scol=kdf._internal.column_scols[0]),
                          anchor=kdf)
        return kdf
238
+
239
+
240
+ class LocIndexer (_LocIndexerLike ):
188
241
"""
189
242
Access a group of rows and columns by label(s) or a boolean Series.
190
243
@@ -372,14 +425,14 @@ def _select_rows(self, rows_sel):
372
425
373
426
if isinstance (rows_sel , Series ):
374
427
assert isinstance (rows_sel .spark_type , BooleanType ), rows_sel .spark_type
375
- return rows_sel ._scol
428
+ return rows_sel ._scol , None
376
429
elif isinstance (rows_sel , slice ):
377
430
assert len (self ._internal .index_columns ) > 0
378
431
if rows_sel .step is not None :
379
432
LocIndexer ._raiseNotImplemented ("Cannot use step with Spark." )
380
433
if rows_sel == slice (None ):
381
434
# If slice is None - select everything, so nothing to do
382
- return None
435
+ return None , None
383
436
elif len (self ._internal .index_columns ) == 1 :
384
437
start = rows_sel .start
385
438
stop = rows_sel .stop
@@ -393,7 +446,7 @@ def _select_rows(self, rows_sel):
393
446
cond .append (index_column ._scol <= F .lit (stop ).cast (index_data_type ))
394
447
395
448
if len (cond ) > 0 :
396
- return reduce (lambda x , y : x & y , cond )
449
+ return reduce (lambda x , y : x & y , cond ), None
397
450
else :
398
451
LocIndexer ._raiseNotImplemented ("Cannot use slice for MultiIndex with Spark." )
399
452
elif isinstance (rows_sel , str ):
@@ -406,15 +459,15 @@ def _select_rows(self, rows_sel):
406
459
LocIndexer ._raiseNotImplemented (
407
460
"Cannot use a scalar value for row selection with Spark." )
408
461
if len (rows_sel ) == 0 :
409
- return F .lit (False )
462
+ return F .lit (False ), None
410
463
elif len (self ._internal .index_columns ) == 1 :
411
464
index_column = self ._kdf_or_kser .index .to_series ()
412
465
index_data_type = index_column .spark_type
413
466
if len (rows_sel ) == 1 :
414
- return index_column ._scol == F .lit (rows_sel [0 ]).cast (index_data_type )
467
+ return index_column ._scol == F .lit (rows_sel [0 ]).cast (index_data_type ), None
415
468
else :
416
469
return index_column ._scol .isin (
417
- [F .lit (r ).cast (index_data_type ) for r in rows_sel ])
470
+ [F .lit (r ).cast (index_data_type ) for r in rows_sel ]), None
418
471
else :
419
472
LocIndexer ._raiseNotImplemented ("Cannot select with MultiIndex with Spark." )
420
473
@@ -451,14 +504,9 @@ def _get_from_multiindex_column(self, key, indexes=None):
451
504
452
505
return column_index , columns , returns_series
453
506
454
- def __getitem__ (self , key ):
455
- from databricks .koalas .frame import DataFrame
507
+ def _select_cols (self , cols_sel ):
456
508
from databricks .koalas .series import Series
457
509
458
- rows_sel , cols_sel = _unfold (key , self ._kdf_or_kser if self ._is_series else None )
459
-
460
- cond = self ._select_rows (rows_sel )
461
-
462
510
# make cols_sel a 1-tuple of string if a single string
463
511
if isinstance (cols_sel , Series ):
464
512
cols_sel = _make_col (cols_sel )
@@ -469,6 +517,7 @@ def __getitem__(self, key):
469
517
cols_sel = None
470
518
471
519
returns_series = cols_sel is not None and isinstance (cols_sel , spark .Column )
520
+
472
521
if cols_sel is None :
473
522
column_index = self ._internal .column_index
474
523
columns = self ._internal .column_scols
@@ -507,43 +556,9 @@ def __getitem__(self, key):
507
556
column_index .append (idx )
508
557
found = True
509
558
if not found :
510
- raise KeyError ("['{}'] not in index" .format (key ))
511
-
512
- if cond is None and returns_series :
513
- if self ._is_series :
514
- return self ._kdf_or_kser ._with_new_scol (columns [0 ])
515
- else :
516
- return Series (self ._internal .copy (scol = columns [0 ], column_index = [column_index [0 ]]),
517
- anchor = self ._kdf_or_kser )
518
- else :
519
- try :
520
- sdf = self ._internal ._sdf
521
- if cond is not None :
522
- sdf = sdf .where (cond )
523
-
524
- sdf = sdf .select (self ._internal .index_scols + columns )
525
-
526
- if self ._internal .column_index_names is None :
527
- column_index_names = None
528
- else :
529
- # Manage column index names
530
- level = column_index_level (column_index )
531
- column_index_names = self ._internal .column_index_names [- level :]
532
-
533
- internal = _InternalFrame (sdf = sdf ,
534
- index_map = self ._internal .index_map ,
535
- column_index = column_index ,
536
- column_index_names = column_index_names )
537
- kdf = DataFrame (internal )
538
- except AnalysisException :
539
- raise KeyError ('[{}] don\' t exist in columns'
540
- .format ([col ._jc .toString () for col in columns ]))
559
+ raise KeyError ("['{}'] not in index" .format (name_like_string (key )))
541
560
542
- if returns_series :
543
- return Series (kdf ._internal .copy (scol = kdf ._internal .column_scols [0 ]),
544
- anchor = kdf )
545
- else :
546
- return kdf
561
+ return column_index , columns , returns_series
547
562
548
563
def __setitem__ (self , key , value ):
549
564
from databricks .koalas .frame import DataFrame
@@ -604,7 +619,7 @@ def __setitem__(self, key, value):
604
619
self ._kdf_or_kser [col_sel ] = value
605
620
606
621
607
- class ILocIndexer (_IndexerLike ):
622
+ class ILocIndexer (_LocIndexerLike ):
608
623
"""
609
624
Purely integer-location based indexing for selection by position.
610
625
@@ -746,21 +761,10 @@ def _select_rows(self, rows_sel):
746
761
ILocIndexer ._raiseNotImplemented (".iloc requires numeric slice or conditional "
747
762
"boolean Index, got {}" .format (rows_sel ))
748
763
749
- def __getitem__ (self , key ):
750
- from databricks .koalas .frame import DataFrame
764
+ def _select_cols (self , cols_sel ):
751
765
from databricks .koalas .series import Series
752
766
753
- rows_sel , cols_sel = _unfold (key , self ._kdf_or_kser if self ._is_series else None )
754
-
755
- sdf = self ._internal .sdf
756
- cond , limit = self ._select_rows (rows_sel )
757
- if cond is not None :
758
- sdf = sdf .where (cond )
759
- if limit is not None :
760
- if limit >= 0 :
761
- sdf = sdf .limit (limit )
762
- else :
763
- sdf = sdf .limit (sdf .count () + limit )
767
+ returns_series = cols_sel is not None and isinstance (cols_sel , (Series , int ))
764
768
765
769
# make cols_sel a 1-tuple of string if a single string
766
770
if isinstance (cols_sel , Series ) and cols_sel ._equals (self ._kdf_or_kser ):
@@ -794,19 +798,4 @@ def __getitem__(self, key):
794
798
raise ValueError ("Location based indexing can only have [integer, integer slice, "
795
799
"listlike of integers, boolean array] types, got {}" .format (cols_sel ))
796
800
797
- try :
798
- sdf = sdf .select (self ._internal .index_scols + columns )
799
- internal = _InternalFrame (sdf = sdf ,
800
- index_map = self ._internal .index_map ,
801
- column_index = column_index ,
802
- column_index_names = self ._internal .column_index_names )
803
- kdf = DataFrame (internal )
804
- except AnalysisException :
805
- raise KeyError ('[{}] don\' t exist in columns'
806
- .format ([col ._jc .toString () for col in columns ]))
807
-
808
- if cols_sel is not None and isinstance (cols_sel , (Series , int )):
809
- from databricks .koalas .series import _col
810
- return _col (kdf )
811
- else :
812
- return kdf
801
+ return column_index , columns , returns_series
0 commit comments