From 5446ec7031d247288776b95a89d0e811e5bcb470 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 29 Mar 2021 16:30:55 -0700 Subject: [PATCH] Fix DataFrame.apply to support additional dtypes. --- databricks/koalas/frame.py | 8 ++- databricks/koalas/tests/test_categorical.py | 55 ++++++++++++++++----- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py index ffe3ecfcc..324a5bab5 100644 --- a/databricks/koalas/frame.py +++ b/databricks/koalas/frame.py @@ -2588,6 +2588,7 @@ def apply_func(pdf): self_applied.columns, [return_schema] * len(self_applied.columns) ) return_schema = StructType([StructField(c, t) for c, t in fields_types]) + data_dtypes = [cast(SeriesType, return_type).dtype] * len(self_applied.columns) elif require_column_axis: if axis != 1: raise TypeError( @@ -2596,11 +2597,13 @@ def apply_func(pdf): "was %s" % return_sig ) return_schema = cast(DataFrameType, return_type).spark_type + data_dtypes = cast(DataFrameType, return_type).dtypes else: # any axis is fine. should_return_series = True return_schema = cast(ScalarType, return_type).spark_type return_schema = StructType([StructField(SPARK_DEFAULT_SERIES_NAME, return_schema)]) + data_dtypes = [cast(ScalarType, return_type).dtype] column_labels = [None] if should_use_map_in_pandas: @@ -2621,7 +2624,10 @@ def apply_func(pdf): # Otherwise, it loses index. internal = InternalFrame( - spark_frame=sdf, index_spark_columns=None, column_labels=column_labels + spark_frame=sdf, + index_spark_columns=None, + column_labels=column_labels, + data_dtypes=data_dtypes, ) result = DataFrame(internal) # type: "DataFrame" diff --git a/databricks/koalas/tests/test_categorical.py b/databricks/koalas/tests/test_categorical.py index 1edd21ea9..4ca03ad90 100644 --- a/databricks/koalas/tests/test_categorical.py +++ b/databricks/koalas/tests/test_categorical.py @@ -23,15 +23,27 @@ class CategoricalTest(ReusedSQLTestCase, TestUtils): - def test_categorical_frame(self): - pdf = pd.DataFrame( + @property + def pdf(self): + return pd.DataFrame( { "a": pd.Categorical([1, 2, 3, 1, 2, 3]), - "b": pd.Categorical(["a", "b", "c", "a", "b", "c"], categories=["c", "b", "a"]), + "b": pd.Categorical( + ["b", "a", "c", "c", "b", "a"], categories=["c", "b", "d", "a"] + ), }, - index=pd.Categorical([10, 20, 30, 20, 30, 10], categories=[30, 10, 20], ordered=True), ) - kdf = ks.from_pandas(pdf) + + @property + def kdf(self): + return ks.from_pandas(self.pdf) + + @property + def df_pair(self): + return (self.pdf, self.kdf) + + def test_categorical_frame(self): + pdf, kdf = self.df_pair self.assert_eq(kdf, pdf) self.assert_eq(kdf.a, pdf.a) @@ -94,17 +106,36 @@ def test_factorize(self): self.assert_eq(kcodes.tolist(), pcodes.tolist()) self.assert_eq(kuniques, puniques) - def test_groupby_apply(self): + def test_frame_apply(self): + pdf, kdf = self.df_pair + + self.assert_eq(kdf.apply(lambda x: x).sort_index(), pdf.apply(lambda x: x).sort_index()) + self.assert_eq( + kdf.apply(lambda x: x, axis=1).sort_index(), pdf.apply(lambda x: x, axis=1).sort_index() + ) + + def test_frame_apply_without_shortcut(self): + with ks.option_context("compute.shortcut_limit", 0): + self.test_frame_apply() + pdf = pd.DataFrame( - { - "a": pd.Categorical([1, 2, 3, 1, 2, 3]), - "b": pd.Categorical( - ["b", "a", "c", "c", "b", "a"], categories=["c", "b", "d", "a"] - ), - }, + {"a": ["a", "b", "c", "a", "b", "c"], "b": ["b", "a", "c", "c", "b", "a"]} ) kdf = ks.from_pandas(pdf) + dtype = CategoricalDtype(categories=["a", "b", "c"]) + + def categorize(ser) -> ks.Series[dtype]: + return ser.astype(dtype) + + self.assert_eq( + kdf.apply(categorize).sort_values(["a", "b"]).reset_index(drop=True), + pdf.apply(categorize).sort_values(["a", "b"]).reset_index(drop=True), + ) + + def test_groupby_apply(self): + pdf, kdf = self.df_pair + self.assert_eq( kdf.groupby("a").apply(lambda df: df).sort_index(), pdf.groupby("a").apply(lambda df: df).sort_index(),