
Introduce `koalas_dtype` to consolidate the logic to convert to dtype and Spark DataType. (#2120)
ueshin authored Mar 26, 2021
1 parent 0d2cef0 commit ed0bd49
Showing 3 changed files with 132 additions and 69 deletions.
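At a glance, the commit replaces a repeated two-step conversion with a single helper. A minimal before/after sketch using only names from the diff below (a sketch, assuming a checkout at this commit):

    from pandas.api.types import pandas_dtype
    from databricks.koalas.typedef import as_spark_type, koalas_dtype

    # Before: callers converted to a pandas dtype and a Spark type in two steps.
    dtype = pandas_dtype("int64")
    spark_type = as_spark_type(dtype)

    # After: one call resolves both at once.
    dtype, spark_type = koalas_dtype("int64")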
6 changes: 3 additions & 3 deletions databricks/koalas/base.py
@@ -26,7 +26,7 @@

import numpy as np
import pandas as pd # noqa: F401
from pandas.api.types import is_list_like, pandas_dtype, CategoricalDtype
from pandas.api.types import is_list_like, CategoricalDtype
from pyspark import sql as spark
from pyspark.sql import functions as F, Window, Column
from pyspark.sql.types import (
@@ -55,6 +55,7 @@
Dtype,
as_spark_type,
extension_dtypes,
koalas_dtype,
spark_type_to_pandas_dtype,
)
from databricks.koalas.utils import (
@@ -1052,8 +1053,7 @@ def astype(self, dtype: Union[str, type, Dtype]) -> Union["Index", "Series"]:
>>> ser.rename("a").to_frame().set_index("a").index.astype('int64')
Int64Index([1, 2], dtype='int64', name='a')
"""
dtype = pandas_dtype(dtype)
spark_type = as_spark_type(dtype)
dtype, spark_type = koalas_dtype(dtype)
if not spark_type:
raise ValueError("Type {} not understood".format(dtype))

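The astype change above is the only call site touched in this file, and its public behavior is unchanged; only the internal type resolution is consolidated. A minimal usage sketch (hypothetical Series values; assumes a working Koalas/Spark environment):

    import databricks.koalas as ks

    ser = ks.Series([1, 2], dtype="int32")
    # astype now resolves both targets in one call:
    # koalas_dtype("int64") -> (dtype('int64'), LongType)
    ser.astype("int64")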
121 changes: 69 additions & 52 deletions databricks/koalas/tests/test_typedef.py
@@ -42,11 +42,12 @@
)

from databricks.koalas.typedef import (
infer_return_type,
as_spark_type,
extension_dtypes_available,
extension_float_dtypes_available,
extension_object_dtypes_available,
infer_return_type,
koalas_dtype,
)
from databricks import koalas as ks

@@ -265,75 +266,88 @@ def f() -> ks.Series[pdf.a.dtype]: # type: ignore

self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)

def test_as_spark_type(self):
def test_as_spark_type_koalas_dtype(self):
type_mapper = {
# binary
np.character: BinaryType(),
np.bytes_: BinaryType(),
np.string_: BinaryType(),
bytes: BinaryType(),
np.character: (np.character, BinaryType()),
np.bytes_: (np.bytes_, BinaryType()),
np.string_: (np.bytes_, BinaryType()),
bytes: (np.bytes_, BinaryType()),
# integer
np.int8: ByteType(),
np.byte: ByteType(),
np.int16: ShortType(),
np.int32: IntegerType(),
np.int64: LongType(),
np.int: LongType(),
int: LongType(),
np.int8: (np.int8, ByteType()),
np.byte: (np.int8, ByteType()),
np.int16: (np.int16, ShortType()),
np.int32: (np.int32, IntegerType()),
np.int64: (np.int64, LongType()),
np.int: (np.int64, LongType()),
int: (np.int64, LongType()),
# floating
np.float32: FloatType(),
np.float: DoubleType(),
np.float64: DoubleType(),
float: DoubleType(),
np.float32: (np.float32, FloatType()),
np.float: (np.float64, DoubleType()),
np.float64: (np.float64, DoubleType()),
float: (np.float64, DoubleType()),
# string
np.str: StringType(),
np.unicode_: StringType(),
str: StringType(),
np.str: (np.unicode_, StringType()),
np.unicode_: (np.unicode_, StringType()),
str: (np.unicode_, StringType()),
# bool
np.bool: BooleanType(),
bool: BooleanType(),
np.bool: (np.bool, BooleanType()),
bool: (np.bool, BooleanType()),
# datetime
np.datetime64: TimestampType(),
datetime.datetime: TimestampType(),
np.datetime64: (np.datetime64, TimestampType()),
datetime.datetime: (np.dtype("datetime64[ns]"), TimestampType()),
# DateType
datetime.date: DateType(),
datetime.date: (np.dtype("object"), DateType()),
# DecimalType
decimal.Decimal: DecimalType(38, 18),
decimal.Decimal: (np.dtype("object"), DecimalType(38, 18)),
# ArrayType
np.ndarray: ArrayType(StringType()),
List[bytes]: ArrayType(BinaryType()),
List[np.character]: ArrayType(BinaryType()),
List[np.bytes_]: ArrayType(BinaryType()),
List[np.string_]: ArrayType(BinaryType()),
List[bool]: ArrayType(BooleanType()),
List[np.bool]: ArrayType(BooleanType()),
List[datetime.date]: ArrayType(DateType()),
List[np.int8]: ArrayType(ByteType()),
List[np.byte]: ArrayType(ByteType()),
List[decimal.Decimal]: ArrayType(DecimalType(38, 18)),
List[float]: ArrayType(DoubleType()),
List[np.float]: ArrayType(DoubleType()),
List[np.float64]: ArrayType(DoubleType()),
List[np.float32]: ArrayType(FloatType()),
List[np.int32]: ArrayType(IntegerType()),
List[int]: ArrayType(LongType()),
List[np.int]: ArrayType(LongType()),
List[np.int64]: ArrayType(LongType()),
List[np.int16]: ArrayType(ShortType()),
List[str]: ArrayType(StringType()),
List[np.unicode_]: ArrayType(StringType()),
List[datetime.datetime]: ArrayType(TimestampType()),
List[np.datetime64]: ArrayType(TimestampType()),
np.ndarray: (np.dtype("object"), ArrayType(StringType())),
List[bytes]: (np.dtype("object"), ArrayType(BinaryType())),
List[np.character]: (np.dtype("object"), ArrayType(BinaryType())),
List[np.bytes_]: (np.dtype("object"), ArrayType(BinaryType())),
List[np.string_]: (np.dtype("object"), ArrayType(BinaryType())),
List[bool]: (np.dtype("object"), ArrayType(BooleanType())),
List[np.bool]: (np.dtype("object"), ArrayType(BooleanType())),
List[datetime.date]: (np.dtype("object"), ArrayType(DateType())),
List[np.int8]: (np.dtype("object"), ArrayType(ByteType())),
List[np.byte]: (np.dtype("object"), ArrayType(ByteType())),
List[decimal.Decimal]: (np.dtype("object"), ArrayType(DecimalType(38, 18))),
List[float]: (np.dtype("object"), ArrayType(DoubleType())),
List[np.float]: (np.dtype("object"), ArrayType(DoubleType())),
List[np.float64]: (np.dtype("object"), ArrayType(DoubleType())),
List[np.float32]: (np.dtype("object"), ArrayType(FloatType())),
List[np.int32]: (np.dtype("object"), ArrayType(IntegerType())),
List[int]: (np.dtype("object"), ArrayType(LongType())),
List[np.int]: (np.dtype("object"), ArrayType(LongType())),
List[np.int64]: (np.dtype("object"), ArrayType(LongType())),
List[np.int16]: (np.dtype("object"), ArrayType(ShortType())),
List[str]: (np.dtype("object"), ArrayType(StringType())),
List[np.unicode_]: (np.dtype("object"), ArrayType(StringType())),
List[datetime.datetime]: (np.dtype("object"), ArrayType(TimestampType())),
List[np.datetime64]: (np.dtype("object"), ArrayType(TimestampType())),
# CategoricalDtype
CategoricalDtype(categories=["a", "b", "c"]): LongType(),
CategoricalDtype(categories=["a", "b", "c"]): (
CategoricalDtype(categories=["a", "b", "c"]),
LongType(),
),
}

for numpy_or_python_type, spark_type in type_mapper.items():
for numpy_or_python_type, (dtype, spark_type) in type_mapper.items():
self.assertEqual(as_spark_type(numpy_or_python_type), spark_type)
self.assertEqual(koalas_dtype(numpy_or_python_type), (dtype, spark_type))

with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
as_spark_type(np.dtype("uint64"))

with self.assertRaisesRegex(TypeError, "Type object was not understood."):
as_spark_type(np.dtype("object"))

with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
koalas_dtype(np.dtype("uint64"))

with self.assertRaisesRegex(TypeError, "Type object was not understood."):
koalas_dtype(np.dtype("object"))

@unittest.skipIf(not extension_dtypes_available, "The pandas extension types are not available")
def test_as_spark_type_extension_dtypes(self):
from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
@@ -347,6 +361,7 @@ def test_as_spark_type_extension_dtypes(self):

for extension_dtype, spark_type in type_mapper.items():
self.assertEqual(as_spark_type(extension_dtype), spark_type)
self.assertEqual(koalas_dtype(extension_dtype), (extension_dtype, spark_type))

@unittest.skipIf(
not extension_object_dtypes_available, "The pandas extension object types are not available"
@@ -361,6 +376,7 @@ def test_as_spark_type_extension_object_dtypes(self):

for extension_dtype, spark_type in type_mapper.items():
self.assertEqual(as_spark_type(extension_dtype), spark_type)
self.assertEqual(koalas_dtype(extension_dtype), (extension_dtype, spark_type))

@unittest.skipIf(
not extension_float_dtypes_available, "The pandas extension float types are not available"
@@ -375,3 +391,4 @@ def test_as_spark_type_extension_float_dtypes(self):

for extension_dtype, spark_type in type_mapper.items():
self.assertEqual(as_spark_type(extension_dtype), spark_type)
self.assertEqual(koalas_dtype(extension_dtype), (extension_dtype, spark_type))
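The updated table pairs every input with a (dtype, Spark DataType) tuple, with container types landing on the object dtype. A standalone spot-check of that pairing — a sketch, assuming a local checkout at this commit:

    import numpy as np
    from typing import List
    from pyspark.sql.types import ArrayType, BooleanType, LongType
    from databricks.koalas.typedef import koalas_dtype

    assert koalas_dtype(int) == (np.dtype("int64"), LongType())
    assert koalas_dtype(List[bool]) == (np.dtype("object"), ArrayType(BooleanType()))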
74 changes: 60 additions & 14 deletions databricks/koalas/typedef/typehints.py
@@ -17,21 +17,21 @@
"""
Utilities to deal with types. This is mostly focused on python3.
"""
import typing
import datetime
import decimal
from inspect import getfullargspec
from typing import Generic, List, Tuple, TypeVar, Union # noqa: F401

import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype
from pandas.api.types import CategoricalDtype, pandas_dtype
from pandas.api.extensions import ExtensionDtype

try:
from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype

extension_dtypes_available = True
extension_dtypes = (Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype) # type: typing.Tuple
extension_dtypes = (Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype) # type: Tuple

try:
from pandas import BooleanDtype, StringDtype
@@ -66,17 +66,17 @@
from databricks import koalas as ks # For running doctests and reference resolution in PyCharm.
from databricks.koalas.typedef.string_typehints import resolve_string_type_hint

T = typing.TypeVar("T")
T = TypeVar("T")

Scalar = typing.Union[
Scalar = Union[
int, float, bool, str, bytes, decimal.Decimal, datetime.date, datetime.datetime, None
]

Dtype = typing.Union[np.dtype, ExtensionDtype]
Dtype = Union[np.dtype, ExtensionDtype]


# A column of data, with the data type.
class SeriesType(typing.Generic[T]):
class SeriesType(Generic[T]):
def __init__(self, tpe):
self.tpe = tpe # type: types.DataType

@@ -122,9 +122,7 @@ class NameTypeHolder(object):
tpe = None


def as_spark_type(
tpe: typing.Union[str, type, Dtype], *, raise_error: bool = True
) -> types.DataType:
def as_spark_type(tpe: Union[str, type, Dtype], *, raise_error: bool = True) -> types.DataType:
"""
Given a Python type, returns the equivalent spark type.
Accepts:
Expand All @@ -134,9 +132,10 @@ def as_spark_type(
- dictionaries of field_name -> type
- Python3's typing system
"""
# TODO: Add "boolean" and "string" types.
if isinstance(tpe, np.dtype) and tpe == np.dtype("object"):
pass
# ArrayType
if tpe in (np.ndarray,):
elif tpe in (np.ndarray,):
return types.ArrayType(types.StringType())
elif hasattr(tpe, "__origin__") and issubclass(tpe.__origin__, list): # type: ignore
element_type = as_spark_type(tpe.__args__[0], raise_error=raise_error) # type: ignore
@@ -244,7 +243,15 @@ def spark_type_to_pandas_dtype(
return Float64Dtype()

if isinstance(
spark_type, (types.DateType, types.NullType, types.StructType, types.UserDefinedType)
spark_type,
(
types.DateType,
types.NullType,
types.ArrayType,
types.MapType,
types.StructType,
types.UserDefinedType,
),
):
return np.dtype("object")
elif isinstance(spark_type, types.TimestampType):
@@ -253,6 +260,45 @@
return np.dtype(to_arrow_type(spark_type).to_pandas_dtype())


def koalas_dtype(tpe) -> Tuple[Dtype, types.DataType]:
    """
    Convert input into a pandas only dtype object or a numpy dtype object,
    and its corresponding Spark DataType.

    Parameters
    ----------
    tpe : object to be converted

    Returns
    -------
    tuple of np.dtype or a pandas dtype, and Spark DataType

    Raises
    ------
    TypeError if not a dtype

    Examples
    --------
    >>> koalas_dtype(int)
    (dtype('int64'), LongType)
    >>> koalas_dtype(str)
    (dtype('<U'), StringType)
    >>> koalas_dtype(datetime.date)
    (dtype('O'), DateType)
    >>> koalas_dtype(datetime.datetime)
    (dtype('<M8[ns]'), TimestampType)
    >>> koalas_dtype(List[bool])
    (dtype('O'), ArrayType(BooleanType,true))
    """
    try:
        dtype = pandas_dtype(tpe)
        spark_type = as_spark_type(dtype)
    except TypeError:
        spark_type = as_spark_type(tpe)
        dtype = spark_type_to_pandas_dtype(spark_type)
    return dtype, spark_type


def infer_pd_series_spark_type(pser: pd.Series, dtype: Dtype) -> types.DataType:
"""Infer Spark DataType from pandas Series dtype.
@@ -274,7 +320,7 @@ def infer_pd_series_spark_type(pser: pd.Series, dtype: Dtype) -> types.DataType:
return as_spark_type(dtype)


def infer_return_type(f) -> typing.Union[SeriesType, DataFrameType, ScalarType, UnknownType]:
def infer_return_type(f) -> Union[SeriesType, DataFrameType, ScalarType, UnknownType]:
"""
>>> def func() -> int:
... pass
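Taken together: koalas_dtype tries the pandas-first path and falls back to Spark-side inference when a TypeError is raised along the way, and the expanded spark_type_to_pandas_dtype mapping above makes that fallback well defined for dates, decimals, arrays, and maps. A short sketch of both branches, with results taken from the doctests above:

    import datetime
    from databricks.koalas.typedef import koalas_dtype

    # pandas-first path: pandas_dtype() and as_spark_type() both succeed.
    koalas_dtype(int)            # (dtype('int64'), LongType)

    # Fallback path: a TypeError is raised, so the Spark type is resolved
    # first and the pandas dtype is derived from it.
    koalas_dtype(datetime.date)  # (dtype('O'), DateType)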
