Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Index constructor take Series or Index objects. #2071

Merged
merged 5 commits into from
Mar 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 11 additions & 18 deletions databricks/koalas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from databricks.koalas.config import get_option, option_context
from databricks.koalas.internal import (
InternalFrame,
DEFAULT_SERIES_NAME,
NATURAL_ORDER_COLUMN_NAME,
SPARK_DEFAULT_INDEX_NAME,
)
Expand Down Expand Up @@ -125,23 +124,17 @@ def align_diff_index_ops(func, this_index_ops: "IndexOpsMixin", *args) -> "Index

with option_context("compute.default_index_type", "distributed-sequence"):
if isinstance(this_index_ops, Index) and all(isinstance(col, Index) for col in cols):
return (
cast(
Series,
column_op(func)(
this_index_ops.to_series().reset_index(drop=True),
*[
arg.to_series().reset_index(drop=True)
if isinstance(arg, Index)
else arg
for arg in args
]
),
)
.sort_index()
.to_frame(DEFAULT_SERIES_NAME)
.set_index(DEFAULT_SERIES_NAME)
.index.rename(this_index_ops.name)
return Index(
column_op(func)(
this_index_ops.to_series().reset_index(drop=True),
*[
arg.to_series().reset_index(drop=True)
if isinstance(arg, Index)
else arg
for arg in args
]
).sort_index(),
name=this_index_ops.name,
)
elif isinstance(this_index_ops, Series):
this = this_index_ops.reset_index()
Expand Down
37 changes: 37 additions & 0 deletions databricks/koalas/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,49 @@ class Index(IndexOpsMixin):

>>> ks.Index(list('abc'))
Index(['a', 'b', 'c'], dtype='object')

From a Series:

>>> s = ks.Series([1, 2, 3], index=[10, 20, 30])
>>> ks.Index(s)
Int64Index([1, 2, 3], dtype='int64')

From an Index:

>>> idx = ks.Index([1, 2, 3])
>>> ks.Index(idx)
Int64Index([1, 2, 3], dtype='int64')
"""

def __new__(cls, data=None, dtype=None, copy=False, name=None, tupleize_cols=True, **kwargs):
if not is_hashable(name):
raise TypeError("Index.name must be a hashable type")

if isinstance(data, Series):
ueshin marked this conversation as resolved.
Show resolved Hide resolved
if dtype is not None:
data = data.astype(dtype)
if name is not None:
data = data.rename(name)

internal = InternalFrame(
spark_frame=data._internal.spark_frame,
index_spark_columns=data._internal.data_spark_columns,
index_names=data._internal.column_labels,
index_dtypes=data._internal.data_dtypes,
column_labels=[],
data_spark_columns=[],
data_dtypes=[],
)
return DataFrame(internal).index
elif isinstance(data, Index):
if copy:
data = data.copy()
if dtype is not None:
data = data.astype(dtype)
if name is not None:
data = data.rename(name)
return data

return ks.from_pandas(
pd.Index(
data=data, dtype=dtype, copy=copy, name=name, tupleize_cols=tupleize_cols, **kwargs
Expand Down
24 changes: 20 additions & 4 deletions databricks/koalas/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,13 @@
from databricks import koalas as ks
from databricks.koalas.indexes.base import Index
from databricks.koalas.missing.indexes import MissingPandasLikeDatetimeIndex
from databricks.koalas.series import Series


class DatetimeIndex(Index):
"""
Immutable ndarray-like of datetime64 data.

Represented internally as int64, and which can be boxed to Timestamp objects
that are subclasses of datetime and carry metadata.

Parameters
----------
data : array-like (1-dimensional), optional
Expand Down Expand Up @@ -63,7 +61,7 @@ class DatetimeIndex(Index):
If True, parse dates in `data` with the day first order.
yearfirst : bool, default False
If True parse dates in `data` with the year first order.
dtype : numpy.dtype or DatetimeTZDtype or str, default None
dtype : numpy.dtype or str, default None
Note that the only NumPy dtype allowed is ‘datetime64[ns]’.
copy : bool, default False
Make a copy of input ndarray.
Expand All @@ -79,6 +77,19 @@ class DatetimeIndex(Index):
--------
>>> ks.DatetimeIndex(['1970-01-01', '1970-01-01', '1970-01-01'])
DatetimeIndex(['1970-01-01', '1970-01-01', '1970-01-01'], dtype='datetime64[ns]', freq=None)

From a Series:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a great catch! :)


>>> from datetime import datetime
>>> s = ks.Series([datetime(2021, 3, 1), datetime(2021, 3, 2)], index=[10, 20])
>>> ks.DatetimeIndex(s)
DatetimeIndex(['2021-03-01', '2021-03-02'], dtype='datetime64[ns]', freq=None)

From an Index:

>>> idx = ks.DatetimeIndex(['1970-01-01', '1970-01-01', '1970-01-01'])
>>> ks.DatetimeIndex(idx)
DatetimeIndex(['1970-01-01', '1970-01-01', '1970-01-01'], dtype='datetime64[ns]', freq=None)
"""

def __new__(
Expand All @@ -97,6 +108,11 @@ def __new__(
if not is_hashable(name):
raise TypeError("Index.name must be a hashable type")

if isinstance(data, (Series, Index)):
if dtype is None:
dtype = "datetime64[ns]"
return Index(data, dtype=dtype, copy=copy, name=name)

kwargs = dict(
data=data,
normalize=normalize,
Expand Down
35 changes: 35 additions & 0 deletions databricks/koalas/indexes/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from databricks import koalas as ks
from databricks.koalas.indexes.base import Index
from databricks.koalas.series import Series


class NumericIndex(Index):
Expand Down Expand Up @@ -65,12 +66,29 @@ class Int64Index(IntegerIndex):
--------
>>> ks.Int64Index([1, 2, 3])
Int64Index([1, 2, 3], dtype='int64')

From a Series:

>>> s = ks.Series([1, 2, 3], index=[10, 20, 30])
>>> ks.Int64Index(s)
Int64Index([1, 2, 3], dtype='int64')

From an Index:

>>> idx = ks.Index([1, 2, 3])
>>> ks.Int64Index(idx)
Int64Index([1, 2, 3], dtype='int64')
"""

def __new__(cls, data=None, dtype=None, copy=False, name=None):
    """Build an index from local data or from an existing Series/Index.

    Parameters
    ----------
    data : optional
        Either a ``Series``/``Index`` instance, which is forwarded to the
        generic ``Index`` constructor, or any other value, which is handed
        to ``pd.Int64Index``.
    dtype : optional
        Requested dtype. When ``data`` is a ``Series``/``Index`` and no dtype
        is given, ``"int64"`` is used.
    copy : bool, default False
        Forwarded unchanged to whichever constructor is called.
    name : hashable, optional
        Name of the resulting index.

    Raises
    ------
    TypeError
        If ``name`` is not hashable.
    """
    if not is_hashable(name):
        raise TypeError("Index.name must be a hashable type")

    if not isinstance(data, (Series, Index)):
        # Plain local data: let pandas build the index, then convert it.
        pandas_index = pd.Int64Index(data=data, dtype=dtype, copy=copy, name=name)
        return ks.from_pandas(pandas_index)

    # Series/Index input: delegate to the generic constructor, pinning the
    # dtype to int64 unless the caller asked for something else.
    target_dtype = "int64" if dtype is None else dtype
    return Index(data, dtype=target_dtype, copy=copy, name=name)


Expand Down Expand Up @@ -102,10 +120,27 @@ class Float64Index(NumericIndex):
--------
>>> ks.Float64Index([1.0, 2.0, 3.0])
Float64Index([1.0, 2.0, 3.0], dtype='float64')

From a Series:

>>> s = ks.Series([1, 2, 3], index=[10, 20, 30])
>>> ks.Float64Index(s)
Float64Index([1.0, 2.0, 3.0], dtype='float64')

From an Index:

>>> idx = ks.Index([1, 2, 3])
>>> ks.Float64Index(idx)
Float64Index([1.0, 2.0, 3.0], dtype='float64')
"""

def __new__(cls, data=None, dtype=None, copy=False, name=None):
    """Build an index from local data or from an existing Series/Index.

    Parameters
    ----------
    data : optional
        Either a ``Series``/``Index`` instance, which is forwarded to the
        generic ``Index`` constructor, or any other value, which is handed
        to ``pd.Float64Index``.
    dtype : optional
        Requested dtype. When ``data`` is a ``Series``/``Index`` and no dtype
        is given, ``"float64"`` is used.
    copy : bool, default False
        Forwarded unchanged to whichever constructor is called.
    name : hashable, optional
        Name of the resulting index.

    Raises
    ------
    TypeError
        If ``name`` is not hashable.
    """
    if not is_hashable(name):
        raise TypeError("Index.name must be a hashable type")

    if not isinstance(data, (Series, Index)):
        # Plain local data: let pandas build the index, then convert it.
        pandas_index = pd.Float64Index(data=data, dtype=dtype, copy=copy, name=name)
        return ks.from_pandas(pandas_index)

    # Series/Index input: delegate to the generic constructor, pinning the
    # dtype to float64 unless the caller asked for something else.
    target_dtype = "float64" if dtype is None else dtype
    return Index(data, dtype=target_dtype, copy=copy, name=name)
40 changes: 39 additions & 1 deletion databricks/koalas/tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def pdf(self):
def kdf(self):
return ks.from_pandas(self.pdf)

def test_index(self):
def test_index_basic(self):
for pdf in [
pd.DataFrame(np.random.randn(10, 5), index=np.random.randint(100, size=10)),
pd.DataFrame(
Expand All @@ -63,6 +63,44 @@ def test_index(self):
self.assert_eq(kdf.index, pdf.index)
self.assert_eq(type(kdf.index).__name__, type(pdf.index).__name__)

def test_index_from_series(self):
    """Check that the index constructors accept a Series as ``data``.

    Each Koalas result is compared, via ``assert_eq``, with the result of
    the same constructor call made in pandas on the equivalent pandas Series.
    """
    int_pser = pd.Series([1, 2, 3], name="a", index=[10, 20, 30])
    int_kser = ks.from_pandas(int_pser)

    # Generic constructor: no extra arguments, dtype override, name override.
    for extra in ({}, {"dtype": "float"}, {"name": "x"}):
        self.assert_eq(ks.Index(int_kser, **extra), pd.Index(int_pser, **extra))

    # On pandas older than 1.1 the expected index is renamed to "a" explicitly
    # (presumably those versions do not carry the Series name over -- confirm).
    needs_rename = not (LooseVersion(pd.__version__) >= LooseVersion("1.1"))
    for index_cls in ("Int64Index", "Float64Index"):
        actual = getattr(ks, index_cls)(int_kser)
        expected = getattr(pd, index_cls)(int_pser)
        if needs_rename:
            expected = expected.rename("a")
        self.assert_eq(actual, expected)

    dt_pser = pd.Series(
        [datetime(2021, 3, 1), datetime(2021, 3, 2)], name="x", index=[10, 20]
    )
    dt_kser = ks.from_pandas(dt_pser)

    self.assert_eq(ks.Index(dt_kser), pd.Index(dt_pser))
    self.assert_eq(ks.DatetimeIndex(dt_kser), pd.DatetimeIndex(dt_pser))

def test_index_from_index(self):
    """Check that the index constructors accept an Index as ``data``.

    Each Koalas result is compared, via ``assert_eq``, with the result of
    the same constructor call made in pandas on the equivalent pandas Index.
    """
    int_pidx = pd.Index([1, 2, 3], name="a")
    int_kidx = ks.from_pandas(int_pidx)

    # Generic constructor: no extra arguments, dtype override, name override.
    for extra in ({}, {"dtype": "float"}, {"name": "x"}):
        self.assert_eq(ks.Index(int_kidx, **extra), pd.Index(int_pidx, **extra))

    # Typed numeric constructors, looked up by name on both libraries.
    for index_cls in ("Int64Index", "Float64Index"):
        self.assert_eq(getattr(ks, index_cls)(int_kidx), getattr(pd, index_cls)(int_pidx))

    dt_pidx = pd.DatetimeIndex(["2021-03-01", "2021-03-02"])
    dt_kidx = ks.from_pandas(dt_pidx)

    self.assert_eq(ks.Index(dt_kidx), pd.Index(dt_pidx))
    self.assert_eq(ks.DatetimeIndex(dt_kidx), pd.DatetimeIndex(dt_pidx))

def test_index_getattr(self):
kidx = self.kdf.index
item = "databricks"
Expand Down