From db4c122dcd293a9e82c2dc394ae22f6d4c2837a6 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 23 Mar 2021 15:28:08 -0700 Subject: [PATCH] Allow multi-index column names for inferring return type schema with names. --- databricks/koalas/tests/test_typedef.py | 10 ++++++++++ databricks/koalas/typedef/typehints.py | 10 +++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/databricks/koalas/tests/test_typedef.py b/databricks/koalas/tests/test_typedef.py index 41fa95ba9d..a7e4d027a8 100644 --- a/databricks/koalas/tests/test_typedef.py +++ b/databricks/koalas/tests/test_typedef.py @@ -135,6 +135,16 @@ def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]: expected = StructType([StructField("a", LongType()), StructField("b", LongType())]) self.assertEqual(infer_return_type(func).tpe, expected) + pdf = pd.DataFrame({("x", "a"): [1, 2, 3], ("y", "b"): [3, 4, 5]}) + + def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]: + pass + + expected = StructType( + [StructField("(x, a)", LongType()), StructField("(y, b)", LongType())] + ) + self.assertEqual(infer_return_type(func).tpe, expected) + @unittest.skipIf( sys.version_info < (3, 7), "Type inference from pandas instances is supported with Python 3.7+", diff --git a/databricks/koalas/typedef/typehints.py b/databricks/koalas/typedef/typehints.py index 1811456df6..f9d982ef51 100644 --- a/databricks/koalas/typedef/typehints.py +++ b/databricks/koalas/typedef/typehints.py @@ -86,6 +86,8 @@ def __repr__(self): class DataFrameType(object): def __init__(self, tpe, names=None): + from databricks.koalas.utils import name_like_string + if names is None: # Default names `c0, c1, ... cn`. self.tpe = types.StructType( @@ -93,7 +95,7 @@ def __init__(self, tpe, names=None): ) # type: types.StructType else: self.tpe = types.StructType( - [types.StructField(n, t) for n, t in zip(names, tpe)] + [types.StructField(name_like_string(n), t) for n, t in zip(names, tpe)] ) # type: types.StructType def __repr__(self): @@ -338,6 +340,12 @@ def infer_return_type(f) -> typing.Union[SeriesType, DataFrameType, ScalarType, ... pass >>> infer_return_type(func).tpe StructType(List(StructField(a,LongType,true),StructField(b,LongType,true))) + + >>> pdf = pd.DataFrame({("x", "a"): [1, 2, 3], ("y", "b"): [3, 4, 5]}) + >>> def func() -> ks.DataFrame[zip(pdf.columns, pdf.dtypes)]: + ... pass + >>> infer_return_type(func).tpe + StructType(List(StructField((x, a),LongType,true),StructField((y, b),LongType,true))) """ # We should re-import to make sure the class 'SeriesType' is not treated as a class # within this module locally. See Series.__class_getitem__ which imports this class