
Type annotations to Koalas accessors and Spark accessors (#1902)
xinrong-meng authored Nov 11, 2020
1 parent b4cb45e commit 44d45f2
Showing 2 changed files with 37 additions and 17 deletions.
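The annotations added here lean on two PEP 484 idioms: `typing.cast` to narrow a base-class return type inside a subclass override, and `# type:` comments for local variables. A minimal standalone sketch of the pattern, with hypothetical class names:

from typing import Union, cast

class Base:
    def transform(self) -> Union[int, str]:
        return 1

class IntChild(Base):
    def transform(self) -> int:
        # cast() narrows the Union for the type checker; it does nothing at runtime
        return cast(int, super().transform())

result = IntChild().transform()  # type: int  # comment-style variable annotation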
databricks/koalas/accessors.py (8 additions, 8 deletions)
@@ -18,7 +18,7 @@
"""
import inspect
from distutils.version import LooseVersion
-from typing import Any, Tuple, Union, TYPE_CHECKING
+from typing import Any, Tuple, Union, TYPE_CHECKING, cast
import types

import numpy as np # noqa: F401
@@ -185,7 +185,7 @@ def attach_id_column(self, id_type: str, column: Union[Any, Tuple]) -> "DataFram
).resolved_copy
)

-def apply_batch(self, func, args=(), **kwds):
+def apply_batch(self, func, args=(), **kwds) -> "DataFrame":
"""
Apply a function that takes a pandas DataFrame and outputs a pandas DataFrame. The pandas
DataFrame given to the function is a batch used internally.
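A minimal usage sketch of the annotated method via the `koalas` accessor (data hypothetical):

>>> import databricks.koalas as ks
>>> kdf = ks.DataFrame([[1, 2], [3, 4]], columns=['a', 'b'])
>>> def pandas_plus(pdf):
...     return pdf + 1  # pdf is a pandas DataFrame; a pandas DataFrame must come back
>>> kdf.koalas.apply_batch(pandas_plus)  # doctest: +SKIP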
@@ -330,7 +330,7 @@ def apply_batch(self, func, args=(), **kwds):
original_func = func
func = lambda o: original_func(o, *args, **kwds)

-self_applied = DataFrame(self._kdf._internal.resolved_copy)
+self_applied = DataFrame(self._kdf._internal.resolved_copy) # type: DataFrame

if should_infer_schema:
# Here we execute with the first 1000 rows to get the return type.
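Koalas also accepts a return-type hint on the user function, which skips this sampling-based inference; a minimal sketch, reusing the `ks` import above:

>>> def pandas_plus(pdf) -> ks.DataFrame[int]:
...     return pdf + 1  # an explicit hint means no inference pass over sample rows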
@@ -343,7 +343,7 @@ def apply_batch(self, func, args=(), **kwds):
"The given function should return a frame; however, "
"the return type was %s." % type(applied)
)
-kdf = ks.DataFrame(applied)
+kdf = ks.DataFrame(applied) # type: DataFrame
if len(pdf) <= limit:
return kdf

@@ -389,7 +389,7 @@ def apply_batch(self, func, args=(), **kwds):

return DataFrame(internal)

-def transform_batch(self, func, *args, **kwargs):
+def transform_batch(self, func, *args, **kwargs) -> Union["DataFrame", "Series"]:
"""
Transform chunks with a function that takes a pandas DataFrame and outputs a pandas DataFrame.
The pandas DataFrame given to the function is a batch used internally. The length of
@@ -450,7 +450,7 @@ def transform_batch(self, func, *args, **kwargs):
Returns
-------
-DataFrame
+DataFrame or Series
See Also
--------
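The widened return annotation matches the two shapes the wrapped function may produce; a minimal sketch (data hypothetical):

>>> kdf = ks.DataFrame([[1, 2], [3, 4]], columns=['a', 'b'])
>>> kdf.koalas.transform_batch(lambda pdf: pdf + 1)    # frame in, frame out  # doctest: +SKIP
>>> kdf.koalas.transform_batch(lambda pdf: pdf.b + 1)  # pandas Series out -> ks.Series  # doctest: +SKIP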
@@ -594,12 +594,12 @@ def pandas_frame_func(f):
if len(pdf) <= limit:
# only do the short cut when it returns a frame to avoid
# operations on different dataframes in case of series.
-return kdf
+return cast(ks.DataFrame, kdf)

# Force nullability.
return_schema = as_nullable_spark_type(kdf._internal.to_internal_spark_frame.schema)

-self_applied = DataFrame(self._kdf._internal.resolved_copy)
+self_applied = DataFrame(self._kdf._internal.resolved_copy) # type: DataFrame

output_func = GroupBy._make_pandas_df_builder_func(
self_applied, func, return_schema, retain_index=True
databricks/koalas/spark/accessors.py (29 additions, 9 deletions)
@@ -19,7 +19,7 @@
but Spark has it.
"""
from distutils.version import LooseVersion
-from typing import TYPE_CHECKING, Optional, Union, List
+from typing import TYPE_CHECKING, Optional, Union, List, cast

import pyspark
from pyspark import StorageLevel
@@ -59,7 +59,7 @@ def column(self) -> Column:
"""
return self._data._internal.spark_column_for(self._data._column_label)

-def transform(self, func):
+def transform(self, func) -> Union["ks.Series", "ks.Index"]:
"""
Applies a function that takes and returns a Spark column. It lets you natively
apply a Spark function and column APIs with the Spark column internally used
@@ -126,7 +126,7 @@ def transform(self, func):

class SparkSeriesMethods(SparkIndexOpsMethods):
def transform(self, func) -> "ks.Series":
-return super().transform(func)
+return cast("ks.Series", super().transform(func))

transform.__doc__ = SparkIndexOpsMethods.transform.__doc__
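A minimal usage sketch of the narrowed `Series` override (data hypothetical); the function receives the underlying Spark `Column`:

>>> from pyspark.sql import functions as F
>>> kser = ks.Series([1, 4, 9])
>>> kser.spark.transform(lambda c: F.sqrt(c))  # doctest: +SKIP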

@@ -252,7 +252,7 @@ def analyzed(self) -> "ks.Series":

class SparkIndexMethods(SparkIndexOpsMethods):
def transform(self, func) -> "ks.Index":
-return super().transform(func)
+return cast("ks.Index", super().transform(func))

transform.__doc__ = SparkIndexOpsMethods.transform.__doc__

@@ -295,7 +295,7 @@ def schema(self, index_col: Optional[Union[str, List[str]]] = None) -> StructTyp
"""
return self.frame(index_col).schema

-def print_schema(self, index_col: Optional[Union[str, List[str]]] = None):
+def print_schema(self, index_col: Optional[Union[str, List[str]]] = None) -> None:
"""
Prints out the underlying Spark schema in the tree format.
@@ -305,6 +305,10 @@ def print_schema(self, index_col: Optional[Union[str, List[str]]] = None):
Column names to be used in Spark to represent Koalas' index. The index name
in Koalas is ignored. By default, the index is always lost.
+Returns
+-------
+None
Examples
--------
>>> df = ks.DataFrame({'a': list('abc'),
@@ -634,7 +638,7 @@ def to_table(
partition_cols: Optional[Union[str, List[str]]] = None,
index_col: Optional[Union[str, List[str]]] = None,
**options
-):
+) -> None:
"""
Write the DataFrame into a Spark table. :meth:`DataFrame.spark.to_table`
is an alias of :meth:`DataFrame.to_table`.
@@ -669,6 +673,10 @@
options
Additional options passed directly to Spark.
+Returns
+-------
+None
See Also
--------
read_table
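A minimal call sketch (table name and options hypothetical):

>>> kdf.spark.to_table('tmp_table', format='parquet', mode='overwrite', partition_cols='a')  # doctest: +SKIP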
@@ -705,7 +713,7 @@ def to_spark_io(
partition_cols: Optional[Union[str, List[str]]] = None,
index_col: Optional[Union[str, List[str]]] = None,
**options
-):
+) -> None:
"""Write the DataFrame out to a Spark data source. :meth:`DataFrame.spark.to_spark_io`
is an alias of :meth:`DataFrame.to_spark_io`.
@@ -736,6 +744,10 @@ def to_spark_io(
options : dict
All other options passed directly into Spark's data source.
+Returns
+-------
+None
See Also
--------
read_spark_io
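A minimal call sketch (path and options hypothetical):

>>> kdf.spark.to_spark_io(path='/tmp/out.json', format='json', mode='overwrite')  # doctest: +SKIP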
@@ -766,7 +778,7 @@ def to_spark_io(
path=path, format=format, mode=mode, partitionBy=partition_cols, **options
)

-def explain(self, extended: Optional[bool] = None, mode: Optional[str] = None):
+def explain(self, extended: Optional[bool] = None, mode: Optional[str] = None) -> None:
"""
Prints the underlying (logical and physical) Spark plans to the console for debugging
purposes.
@@ -778,6 +790,10 @@ def explain(self, extended: Optional[bool] = None, mode: Optional[str] = None):
mode : string, default ``None``.
The expected output format of plans.
+Returns
+-------
+None
Examples
--------
>>> df = ks.DataFrame({'id': range(10)})
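Continuing the setup above, a minimal invocation (plan output elided):

>>> df.spark.explain()  # doctest: +SKIP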
@@ -1164,11 +1180,15 @@ def storage_level(self) -> StorageLevel:
"""
return self._kdf._cached.storageLevel

-def unpersist(self):
+def unpersist(self) -> None:
"""
The `unpersist` function is used to uncache the Koalas DataFrame when it
is not used with a `with` statement.
+Returns
+-------
+None
Examples
--------
>>> df = ks.DataFrame([(.2, .3), (.0, .6), (.6, .0), (.2, .1)],
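Outside a `with` block, the cached frame must be released explicitly; a minimal sketch (data hypothetical):

>>> df = ks.DataFrame([(.2, .3), (.0, .6)], columns=['dogs', 'cats'])
>>> cached = df.spark.cache()
>>> cached.spark.unpersist()  # doctest: +SKIP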
