Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…first_valid_index
  • Loading branch information
itholic committed Oct 24, 2019
2 parents a70b92c + c8dcb64 commit bad5c3f
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 1 deletion.
1 change: 0 additions & 1 deletion databricks/koalas/missing/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ class _MissingPandasLikeSeries(object):
update = unsupported_function('update')
view = unsupported_function('view')
where = unsupported_function('where')
xs = unsupported_function('xs')

# Deprecated functions
as_blocks = unsupported_function('as_blocks', deprecated=True)
Expand Down
96 changes: 96 additions & 0 deletions databricks/koalas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3661,6 +3661,102 @@ def replace(self, to_replace=None, value=None, regex=False) -> 'Series':

return self._with_new_scol(current)

def xs(self, key, level=None):
"""
Return cross-section from the Series.
This method takes a `key` argument to select data at a particular
level of a MultiIndex.
Parameters
----------
key : label or tuple of label
Label contained in the index, or partially in a MultiIndex.
level : object, defaults to first n levels (n=1 or len(key))
In case of a key partially contained in a MultiIndex, indicate
which levels are used. Levels can be referred by label or position.
Returns
-------
Series
Cross-section from the original Series
corresponding to the selected index levels.
Examples
--------
>>> midx = pd.MultiIndex([['a', 'b', 'c'],
... ['lama', 'cow', 'falcon'],
... ['speed', 'weight', 'length']],
... [[0, 0, 0, 1, 1, 1, 2, 2, 2],
... [0, 0, 0, 1, 1, 1, 2, 2, 2],
... [0, 1, 2, 0, 1, 2, 0, 1, 2]])
>>> s = ks.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3],
... index=midx)
>>> s
a lama speed 45.0
weight 200.0
length 1.2
b cow speed 30.0
weight 250.0
length 1.5
c falcon speed 320.0
weight 1.0
length 0.3
Name: 0, dtype: float64
Get values at specified index
>>> s.xs('a')
lama speed 45.0
weight 200.0
length 1.2
Name: 0, dtype: float64
Get values at several indexes
>>> s.xs(('a', 'lama'))
speed 45.0
weight 200.0
length 1.2
Name: 0, dtype: float64
Get values at specified index and level
>>> s.xs('lama', level=1)
a speed 45.0
weight 200.0
length 1.2
Name: 0, dtype: float64
"""
if not isinstance(key, tuple):
key = (key,)
if level is None:
level = 0

cols = (self._internal.index_scols[:level] +
self._internal.index_scols[level+len(key):] +
[self._internal.scol_for(self._internal.column_index[0])])
rows = [self._internal.scols[lvl] == index
for lvl, index in enumerate(key, level)]
sdf = self._internal.sdf \
.select(cols) \
.where(reduce(lambda x, y: x & y, rows))

if len(self._internal._index_map) == len(key):
# if sdf has one column and one data, return data only without frame
pdf = sdf.limit(2).toPandas()
length = len(pdf)
if length == 1:
return pdf[self.name].iloc[0]

index_cols = [col for col in sdf.columns if col not in self._internal.data_columns]
index_map_dict = dict(self._internal.index_map)
internal = self._internal.copy(
sdf=sdf,
index_map=[(index_col, index_map_dict[index_col]) for index_col in index_cols])

return _col(DataFrame(internal))

def _cum(self, func, skipna, part_cols=()):
# This is used to cummin, cummax, cumsum, etc.
index_columns = self._internal.index_columns
Expand Down
13 changes: 13 additions & 0 deletions databricks/koalas/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,19 @@ def test_replace(self):
with self.assertRaisesRegex(NotImplementedError, msg):
kser.replace(r'^1.$', regex=True)

def test_xs(self):
midx = pd.MultiIndex([['a', 'b', 'c'],
['lama', 'cow', 'falcon'],
['speed', 'weight', 'length']],
[[0, 0, 0, 1, 1, 1, 2, 2, 2],
[0, 0, 0, 1, 1, 1, 2, 2, 2],
[0, 1, 2, 0, 1, 2, 0, 1, 2]])
kser = ks.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3],
index=midx)
pser = kser.to_pandas()

self.assert_eq(kser.xs(('a', 'lama', 'speed')), pser.xs(('a', 'lama', 'speed')))

def test_duplicates(self):
# test on texts
pser = pd.Series(['lama', 'cow', 'lama', 'beetle', 'lama', 'hippo'],
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ Indexing, iteration
Series.loc
Series.iloc
Series.keys
Series.xs

Binary operator functions
-------------------------
Expand Down

0 comments on commit bad5c3f

Please sign in to comment.