Fix DataFrame.koalas.transform_batch to support additional dtypes. #2132

Merged: 2 commits merged into databricks:master from dataframe_transform_batch on Apr 1, 2021

Conversation

@ueshin (Collaborator) commented on Apr 1, 2021:

Fix DataFrame.koalas.transform_batch to support additional dtypes.

After this change, additional dtypes can be specified in the return type annotation of the UDF passed to DataFrame.koalas.transform_batch.

>>> import pandas as pd
>>> import databricks.koalas as ks
>>> kdf = ks.DataFrame(
...     {"a": ["a", "b", "c", "a", "b", "c"], "b": ["b", "a", "c", "c", "b", "a"]}
... )
>>> dtype = pd.CategoricalDtype(categories=["a", "b", "c", "d"])
>>> def to_category(pdf) -> ks.DataFrame["a":dtype, "b":dtype]:
...   return pdf.astype(dtype)
...
>>> applied = kdf.koalas.transform_batch(to_category)
>>> applied
   a  b
0  a  b
1  b  a
2  c  c
3  a  c
4  b  b
5  c  a
>>> applied.dtypes
a    category
b    category
dtype: object
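
As a hedged extension of the example above (not from the PR itself; it assumes koalas Series.dtype reflects the annotated dtype, and that ks.Series[...] accepts a dtype the same way), the annotated dtype is kept as-is, including the unused category "d", rather than being re-inferred from the data:

>>> applied["a"].dtype
CategoricalDtype(categories=['a', 'b', 'c', 'd'], ordered=False)
>>> def to_category_series(pdf) -> ks.Series[dtype]:
...     return pdf["a"].astype(dtype)
...
>>> kdf.koalas.transform_batch(to_category_series).dtype
CategoricalDtype(categories=['a', 'b', 'c', 'd'], ordered=False)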

@ueshin ueshin requested a review from xinrong-meng April 1, 2021 00:12
 def pandas_extract(pdf, name):
     # This is for output to work around a DataFrame for struct
     # from Spark 3.0. See SPARK-23836
     return pdf[name]

-def pandas_series_func(f):
+def pandas_series_func(f, by_pass):
@xinrong-meng (Contributor) commented:

Why is it called by_pass? Would you please help me understand?

@ueshin (Collaborator, Author) replied:

It uses some new Spark APIs to "by pass" a workaround:

You can see:

if should_by_pass:
    pudf = pandas_udf(
        output_func, returnType=return_schema, functionType=PandasUDFType.SCALAR
    )
    temp_struct_column = verify_temp_column_name(
        self_applied._internal.spark_frame, "__temp_struct__"
    )
    applied = pudf(F.struct(*columns)).alias(temp_struct_column)
    sdf = self_applied._internal.spark_frame.select(applied)
    sdf = sdf.selectExpr("%s.*" % temp_struct_column)
else:
    applied = []
    for field in return_schema.fields:
        applied.append(
            pandas_udf(
                pandas_frame_func(output_func, field.name),
                returnType=field.dataType,
                functionType=PandasUDFType.SCALAR,
            )(*columns).alias(field.name)
        )
    sdf = self_applied._internal.spark_frame.select(*applied)
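
For context, here is a minimal self-contained PySpark sketch of the "by pass" path (a sketch, assuming Spark >= 3.0, where SPARK-23836 allows a scalar pandas UDF to return a struct; the data and names are illustrative, not the koalas internals):

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([("a", "b"), ("b", "a")], ["a", "b"])

# The struct input arrives as a pandas DataFrame, and returning a
# pandas DataFrame maps to the struct return type.
@pandas_udf("a string, b string")
def upcase(pdf: pd.DataFrame) -> pd.DataFrame:
    return pdf.apply(lambda col: col.str.upper())

# Pack all columns into one struct, call the UDF once, then expand the
# struct back into top-level columns.
out = sdf.select(upcase(F.struct("a", "b")).alias("__temp_struct__"))
out = out.selectExpr("__temp_struct__.*")

Without that capability, the else-branch above has to register one pandas UDF per output column and extract each column from the batch result separately, since scalar pandas UDFs could not return a StructType before SPARK-23836.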

@xinrong-meng (Contributor) commented:

Looks great, thank you!

@ueshin (Collaborator, Author) commented on Apr 1, 2021:

Thanks! merging.

@ueshin ueshin merged commit d7f6e88 into databricks:master Apr 1, 2021
@ueshin ueshin deleted the dataframe_transform_batch branch April 1, 2021 02:37