Commit 0d2cef0

Support ExtensionDtypes as type arguments. (#2106)
Support `ExtensionDtype`s as type arguments by reusing `NameTypeHolder` for DataFrame's type annotation. Also support inferring the Spark DataType from a return type annotation that uses `ExtensionDtype`s.

Before:

```py
>>> ks.Series[pd.Int32Dtype()]
Traceback (most recent call last):
...
TypeError: Parameters to generic types must be types. Got Int32Dtype().
```

After:

```py
>>> ks.Series[pd.Int32Dtype()]
databricks.koalas.typedef.typehints.SeriesType[databricks.koalas.series.NameType]
>>> def a() -> ks.Series[pd.Int32Dtype()]:
...     pass
...
>>> infer_return_type(a)
SeriesType[IntegerType]
```
1 parent a595bfb commit 0d2cef0
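
For orientation, here is a minimal, self-contained sketch (not part of the commit) that exercises the new behavior end to end; it assumes a working Koalas installation and uses only the public names shown in the examples above:

```py
# Sketch only: exercises the new ExtensionDtype support end to end.
import pandas as pd
import databricks.koalas as ks
from databricks.koalas.typedef import infer_return_type

# A nullable integer extension dtype is now a legal type argument ...
def to_nullable_int() -> ks.Series[pd.Int32Dtype()]:
    pass

# ... and the return annotation maps to the corresponding Spark type.
print(infer_return_type(to_nullable_int))  # SeriesType[IntegerType]
```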

File tree

4 files changed (+128 −30 lines changed)

databricks/koalas/frame.py (+17 −2)

```diff
@@ -48,6 +48,7 @@
 import numpy as np
 import pandas as pd
 from pandas.api.types import is_list_like, is_dict_like, is_scalar
+from pandas.api.extensions import ExtensionDtype
 
 if TYPE_CHECKING:
     from pandas.io.formats.style import Styler
@@ -333,6 +334,12 @@
 
 
 def _create_tuple_for_frame_type(params):
+    """
+    This is a workaround to support variadic generic in DataFrame.
+
+    See https://github.com/python/typing/issues/193
+    We always wrap the given type hints in a tuple to mimic the variadic generic.
+    """
     from databricks.koalas.typedef import NameTypeHolder
 
     if isinstance(params, zip):
@@ -365,8 +372,16 @@ def _create_tuple_for_frame_type(params):
 
     if not isinstance(params, Iterable):
         params = [params]
-    params = [param.type if isinstance(param, np.dtype) else param for param in params]
-    return Tuple[tuple(params)]
+
+    new_params = []
+    for param in params:
+        if isinstance(param, ExtensionDtype):
+            new_class = type("NameType", (NameTypeHolder,), {})
+            new_class.tpe = param
+            new_params.append(new_class)
+        else:
+            new_params.append(param.type if isinstance(param, np.dtype) else param)
+    return Tuple[tuple(new_params)]
 
 
 if (3, 5) <= sys.version_info < (3, 7):
```
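
The heart of this change is the dynamic subclassing trick: `typing` only accepts classes as type arguments, so a dtype *instance* like `pd.Int32Dtype()` cannot be passed to `Tuple[...]` directly. A standalone sketch of the pattern (the `Holder` base and `wrap` helper are illustrative stand-ins, not Koalas API):

```py
from typing import Tuple

import pandas as pd

class Holder:
    """Illustrative stand-in for Koalas' NameTypeHolder."""
    tpe = None

def wrap(dtype):
    # Mint a fresh class per parameter and let the dtype instance
    # ride along as a class attribute, which typing never inspects.
    carrier = type("NameType", (Holder,), {})
    carrier.tpe = dtype
    return carrier

params = [wrap(pd.Int32Dtype()), wrap(pd.CategoricalDtype(["a", "b"]))]
hint = Tuple[tuple(params)]          # accepted: every parameter is a class
print(hint.__args__[0].tpe)          # Int32Dtype() is recoverable later
```

Minting a fresh class per parameter matters: reusing a single class would make every column share the same `tpe` attribute.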

databricks/koalas/series.py (+35 −3)

```diff
@@ -31,6 +31,7 @@
 from pandas.core.accessor import CachedAccessor
 from pandas.io.formats.printing import pprint_thing
 from pandas.api.types import is_list_like, is_hashable
+from pandas.api.extensions import ExtensionDtype
 import pyspark
 from pyspark import sql as spark
 from pyspark.sql import functions as F, Column
@@ -86,9 +87,9 @@
 from databricks.koalas.typedef import (
     infer_return_type,
     spark_type_to_pandas_dtype,
-    SeriesType,
     ScalarType,
     Scalar,
+    SeriesType,
 )
 
 
@@ -320,6 +321,32 @@
 str_type = str
 
 
+def _create_type_for_series_type(param):
+    from databricks.koalas.typedef import NameTypeHolder
+
+    if isinstance(param, ExtensionDtype):
+        new_class = type("NameType", (NameTypeHolder,), {})
+        new_class.tpe = param
+    else:
+        new_class = param.type if isinstance(param, np.dtype) else param
+
+    return SeriesType[new_class]
+
+
+if (3, 5) <= sys.version_info < (3, 7):
+    from typing import GenericMeta  # type: ignore
+
+    old_getitem = GenericMeta.__getitem__  # type: ignore
+
+    def new_getitem(self, params):
+        if hasattr(self, "is_series"):
+            return old_getitem(self, _create_type_for_series_type(params))
+        else:
+            return old_getitem(self, params)
+
+    GenericMeta.__getitem__ = new_getitem  # type: ignore
+
+
 class Series(Frame, IndexOpsMixin, Generic[T]):
     """
     Koalas Series that corresponds to pandas Series logically. This holds Spark Column
@@ -5978,8 +6005,13 @@ def __iter__(self):
 
     if sys.version_info >= (3, 7):
         # In order to support the type hints such as Series[...]. See DataFrame.__class_getitem__.
-        def __class_getitem__(cls, tpe):
-            return SeriesType[tpe]
+        def __class_getitem__(cls, params):
+            return _create_type_for_series_type(params)
+
+    elif (3, 5) <= sys.version_info < (3, 7):
+        # The implementation is in its metaclass so this flag is needed to distinguish
+        # Koalas Series.
+        is_series = None
 
 
 def unpack_scalar(sdf):
```
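
The split between `__class_getitem__` and the `GenericMeta` patch exists because PEP 560 only took effect in Python 3.7. A rough sketch of the 3.7+ path, using illustrative names rather than Koalas code:

```py
import pandas as pd

class FakeSeriesType:
    """Illustrative stand-in for Koalas' SeriesType wrapper."""
    def __init__(self, tpe):
        self.tpe = tpe
    def __repr__(self):
        return "SeriesType[%r]" % (self.tpe,)

class FakeSeries:
    # PEP 560 (Python 3.7+): subscription reaches this hook before any
    # typing-level validation, so dtype instances arrive untouched.
    def __class_getitem__(cls, params):
        return FakeSeriesType(params)

print(FakeSeries[pd.Int32Dtype()])   # SeriesType[Int32Dtype()]
```

On 3.5/3.6 there is no such hook, which is why the diff patches `typing.GenericMeta.__getitem__` globally and uses the `is_series` marker attribute to restrict the interception to Koalas' own `Series`.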

databricks/koalas/tests/test_typedef.py (+40)

```diff
@@ -21,6 +21,7 @@
 
 import pandas
 import pandas as pd
+from pandas.api.types import CategoricalDtype
 import numpy as np
 from pyspark.sql.types import (
     ArrayType,
@@ -103,6 +104,19 @@ def func() -> pd.DataFrame[pdf.dtypes]:  # type: ignore
         expected = StructType([StructField("c0", LongType()), StructField("c1", LongType())])
         self.assertEqual(infer_return_type(func).tpe, expected)
 
+        pdf = pd.DataFrame({"a": [1, 2, 3], "b": pd.Categorical(["a", "b", "c"])})
+
+        def func() -> pd.Series[pdf.b.dtype]:  # type: ignore
+            pass
+
+        self.assertEqual(infer_return_type(func).tpe, LongType())
+
+        def func() -> pd.DataFrame[pdf.dtypes]:  # type: ignore
+            pass
+
+        expected = StructType([StructField("c0", LongType()), StructField("c1", LongType())])
+        self.assertEqual(infer_return_type(func).tpe, expected)
+
     def test_if_pandas_implements_class_getitem(self):
         # the current type hint implementation of pandas DataFrame assumes pandas doesn't
         # implement '__class_getitem__'. This test case is to make sure pandas
@@ -145,6 +159,14 @@ def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]:
         )
         self.assertEqual(infer_return_type(func).tpe, expected)
 
+        pdf = pd.DataFrame({"a": [1, 2, 3], "b": pd.Categorical(["a", "b", "c"])})
+
+        def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]:
+            pass
+
+        expected = StructType([StructField("a", LongType()), StructField("b", LongType())])
+        self.assertEqual(infer_return_type(func).tpe, expected)
+
     @unittest.skipIf(
         sys.version_info < (3, 7),
         "Type inference from pandas instances is supported with Python 3.7+",
@@ -188,6 +210,14 @@ def f() -> pd.DataFrame[pdf.dtypes]:  # type: ignore
 
         self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
 
+        def try_infer_return_type():
+            def f() -> pd.Series[pdf.a.dtype]:  # type: ignore
+                pass
+
+            infer_return_type(f)
+
+        self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
+
     def test_infer_schema_with_names_negative(self):
         def try_infer_return_type():
             def f() -> 'ks.DataFrame["a" : np.float : 1, "b":str:2]':  # noqa: F821
@@ -227,6 +257,14 @@ def f() -> ks.DataFrame[pdf.dtypes]:  # type: ignore
 
         self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
 
+        def try_infer_return_type():
+            def f() -> ks.Series[pdf.a.dtype]:  # type: ignore
+                pass
+
+            infer_return_type(f)
+
+        self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
+
     def test_as_spark_type(self):
         type_mapper = {
             # binary
@@ -286,6 +324,8 @@ def test_as_spark_type(self):
             List[np.unicode_]: ArrayType(StringType()),
             List[datetime.datetime]: ArrayType(TimestampType()),
             List[np.datetime64]: ArrayType(TimestampType()),
+            # CategoricalDtype
+            CategoricalDtype(categories=["a", "b", "c"]): LongType(),
         }
 
         for numpy_or_python_type, spark_type in type_mapper.items():
```

databricks/koalas/typedef/typehints.py (+36 −25)

```diff
@@ -20,7 +20,7 @@
 import typing
 import datetime
 import decimal
-from inspect import getfullargspec, isclass
+from inspect import getfullargspec
 
 import numpy as np
 import pandas as pd
@@ -85,18 +85,15 @@ def __repr__(self):
 
 
 class DataFrameType(object):
-    def __init__(self, tpe, names=None):
+    def __init__(self, tpe, names):
         from databricks.koalas.utils import name_like_string
 
-        if names is None:
-            # Default names `c0, c1, ... cn`.
-            self.tpe = types.StructType(
-                [types.StructField("c%s" % i, tpe[i]) for i in range(len(tpe))]
-            )  # type: types.StructType
-        else:
-            self.tpe = types.StructType(
-                [types.StructField(name_like_string(n), t) for n, t in zip(names, tpe)]
-            )  # type: types.StructType
+        self.tpe = types.StructType(
+            [
+                types.StructField(name_like_string(n) if n is not None else ("c%s" % i), t)
+                for i, (n, t) in enumerate(zip(names, tpe))
+            ]
+        )  # type: types.StructType
 
     def __repr__(self):
         return "DataFrameType[{}]".format(self.tpe)
@@ -346,6 +343,22 @@ def infer_return_type(f) -> typing.Union[SeriesType, DataFrameType, ScalarType,
     ...     pass
     >>> infer_return_type(func).tpe
     StructType(List(StructField((x, a),LongType,true),StructField((y, b),LongType,true)))
+
+    >>> pdf = pd.DataFrame({"a": [1, 2, 3], "b": pd.Categorical([3, 4, 5])})
+    >>> def func() -> ks.DataFrame[pdf.dtypes]:
+    ...     pass
+    >>> infer_return_type(func).tpe
+    StructType(List(StructField(c0,LongType,true),StructField(c1,LongType,true)))
+
+    >>> def func() -> ks.DataFrame[zip(pdf.columns, pdf.dtypes)]:
+    ...     pass
+    >>> infer_return_type(func).tpe
+    StructType(List(StructField(a,LongType,true),StructField(b,LongType,true)))
+
+    >>> def func() -> ks.Series[pdf.b.dtype]:
+    ...     pass
+    >>> infer_return_type(func).tpe
+    LongType
     """
     # We should re-import to make sure the class 'SeriesType' is not treated as a class
     # within this module locally. See Series.__class_getitem__ which imports this class
@@ -357,17 +370,19 @@ def infer_return_type(f) -> typing.Union[SeriesType, DataFrameType, ScalarType,
     if isinstance(tpe, str):
         # This type hint can happen when given hints are string to avoid forward reference.
         tpe = resolve_string_type_hint(tpe)
+
     if hasattr(tpe, "__origin__") and (
-        issubclass(tpe.__origin__, SeriesType) or tpe.__origin__ == ks.Series
+        tpe.__origin__ == ks.DataFrame or tpe.__origin__ == ks.Series
     ):
-        # TODO: remove "tpe.__origin__ == ks.Series" when we drop Python 3.5 and 3.6.
-        inner = as_spark_type(tpe.__args__[0])
-        return SeriesType(inner)
+        # When Python version is lower than 3.7, unwrap it to a Tuple/SeriesType type hint.
+        tpe = tpe.__args__[0]
 
-    if hasattr(tpe, "__origin__") and tpe.__origin__ == ks.DataFrame:
-        # When Python version is lower then 3.7. Unwrap it to a Tuple type
-        # hints.
+    if hasattr(tpe, "__origin__") and issubclass(tpe.__origin__, SeriesType):
         tpe = tpe.__args__[0]
+        if issubclass(tpe, NameTypeHolder):
+            tpe = tpe.tpe
+        inner = as_spark_type(tpe)
+        return SeriesType(inner)
 
     # Note that, DataFrame type hints will create a Tuple.
     # Python 3.6 has `__name__`. Python 3.7 and 3.8 have `_name`.
@@ -381,13 +396,9 @@ def infer_return_type(f) -> typing.Union[SeriesType, DataFrameType, ScalarType,
             parameters = getattr(tuple_type, "__tuple_params__")
         else:
             parameters = getattr(tuple_type, "__args__")
-        if len(parameters) > 0 and all(
-            isclass(p) and issubclass(p, NameTypeHolder) for p in parameters
-        ):
-            names = [p.name for p in parameters if issubclass(p, NameTypeHolder)]
-            types = [p.tpe for p in parameters if issubclass(p, NameTypeHolder)]
-            return DataFrameType([as_spark_type(t) for t in types], names)
-        return DataFrameType([as_spark_type(t) for t in parameters])
+        names = [p.name if issubclass(p, NameTypeHolder) else None for p in parameters]
+        types = [p.tpe if issubclass(p, NameTypeHolder) else p for p in parameters]
+        return DataFrameType([as_spark_type(t) for t in types], names)
     inner = as_spark_type(tpe)
     if inner is None:
         return UnknownType(tpe)
```
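
The rewritten Tuple branch treats named and anonymous parameters uniformly instead of the old all-or-nothing check. A distilled, standalone sketch (`Holder` and `make_holder` are illustrative stand-ins) of how names and types are pulled apart, and how missing names fall back to positional defaults inside `DataFrameType`:

```py
class Holder:
    """Illustrative stand-in for NameTypeHolder."""
    name = None
    tpe = None

def make_holder(name, tpe):
    h = type("NameType", (Holder,), {})
    h.name, h.tpe = name, tpe
    return h

# Mixed parameters: one named holder, one plain type.
parameters = [make_holder("a", int), float]

names = [p.name if issubclass(p, Holder) else None for p in parameters]
types = [p.tpe if issubclass(p, Holder) else p for p in parameters]

# DataFrameType now applies the "c0, c1, ..." default only per missing name.
fields = [(n if n is not None else "c%s" % i, t)
          for i, (n, t) in enumerate(zip(names, types))]
print(fields)  # [('a', <class 'int'>), ('c1', <class 'float'>)]
```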
