class CachedAccessor:
    """
    Custom property-like object.

    A descriptor for caching accessors: the accessor is built the first time it is looked up on
    an instance and then stored on that instance under the same name.

    Parameters
    ----------
    name : str
        Namespace that accessor's methods, properties, etc will be accessed under, e.g. "foo" for a
        dataframe accessor yields the accessor ``df.foo``
    accessor: cls
        Class with the extension methods.

    Notes
    -----
    For accessor, the class's __init__ method assumes that you are registering an accessor for one
    of ``Series``, ``DataFrame``, or ``Index``.

    This object is not meant to be instantiated directly. Instead, use register_dataframe_accessor,
    register_series_accessor, or register_index_accessor.

    The Koalas accessor is modified based on pandas.core.accessor.
    """

    def __init__(self, name, accessor):
        self._name = name
        self._accessor = accessor

    def __get__(self, obj, cls):
        """
        Return the accessor class for class-level access, or a cached accessor instance.

        Parameters
        ----------
        obj : object or None
            Instance the attribute was looked up on; ``None`` for class-level access.
        cls : type
            Owner class (unused).

        Returns
        -------
        object
            ``self._accessor`` itself when ``obj`` is None, otherwise ``self._accessor(obj)``.
        """
        if obj is None:
            # Looked up on the class itself: hand back the accessor class.
            return self._accessor
        accessor_obj = self._accessor(obj)
        # Store the accessor in the instance so later lookups find it there first (this
        # descriptor defines no ``__set__``). ``object.__setattr__`` is used instead of
        # ``setattr`` so that a ``__setattr__`` override on the host class cannot intercept
        # or reject the cache write.
        object.__setattr__(obj, self._name, accessor_obj)
        return accessor_obj
def _register_accessor(name, cls):
    """
    Register a custom accessor on objects of type ``cls``.

    Parameters
    ----------
    name : str
        Name under which the accessor should be registered. A warning is issued if this name
        conflicts with a preexisting attribute.
    cls : type
        Class the accessor is attached to. The public wrappers in this module pass the Koalas
        ``DataFrame``, ``Series`` or ``Index`` class.

    Returns
    -------
    callable
        A class decorator.

    See Also
    --------
    register_dataframe_accessor: Register a custom accessor on DataFrame objects
    register_series_accessor: Register a custom accessor on Series objects
    register_index_accessor: Register a custom accessor on Index objects

    Notes
    -----
    When accessed, your accessor will be initialized with the Koalas object the user is interacting
    with. The code signature must be:

    .. code-block:: python

        def __init__(self, koalas_obj):
            # constructor logic
        ...

    In the pandas API, if data passed to your accessor has an incorrect dtype, it's recommended to
    raise an ``AttributeError`` for consistency purposes. In Koalas, ``ValueError`` is more
    frequently used to annotate when a value's datatype is unexpected for a given method/function.

    Ultimately, you can structure this however you like, but Koalas would likely do something like
    this:

    >>> ks.Series(['a', 'b']).dt
    ...
    Traceback (most recent call last):
        ...
    ValueError: Cannot call DatetimeMethods on type StringType

    Note: This function is not meant to be used directly - instead, use register_dataframe_accessor,
    register_series_accessor, or register_index_accessor.
    """

    def decorator(accessor):
        if hasattr(cls, name):
            # Adjacent literals are used instead of a backslash continuation inside the string,
            # which would embed the next source line's indentation in the message.
            template = (
                "registration of accessor {0} under name {1} for type {2} is overriding "
                "a preexisting attribute with the same name."
            )
            # stacklevel=2 attributes the warning to the code applying the decorator.
            warnings.warn(template.format(accessor, name, cls), UserWarning, stacklevel=2)
        setattr(cls, name, CachedAccessor(name, accessor))
        return accessor

    return decorator
def register_dataframe_accessor(name):
    """
    Register a custom accessor with a DataFrame

    Parameters
    ----------
    name : str
        name used when calling the accessor after it is registered

    Returns
    -------
    callable
        A class decorator.

    See Also
    --------
    register_series_accessor: Register a custom accessor on Series objects
    register_index_accessor: Register a custom accessor on Index objects

    Notes
    -----
    When accessed, your accessor will be initialized with the Koalas object the user is interacting
    with. The accessor's init method should always ingest the object being accessed. See the
    examples for the init signature.

    In the pandas API, if data passed to your accessor has an incorrect dtype, it's recommended to
    raise an ``AttributeError`` for consistency purposes. In Koalas, ``ValueError`` is more
    frequently used to annotate when a value's datatype is unexpected for a given method/function.

    Ultimately, you can structure this however you like, but Koalas would likely do something like
    this:

    >>> ks.Series(['a', 'b']).dt
    ...
    Traceback (most recent call last):
        ...
    ValueError: Cannot call DatetimeMethods on type StringType

    Examples
    --------
    In your library code::

        import databricks.koalas as ks

        @ks.extensions.register_dataframe_accessor("geo")
        class GeoAccessor:

            def __init__(self, koalas_obj):
                self._obj = koalas_obj
                # other constructor logic

            @property
            def center(self):
                # return the geographic center point of this DataFrame
                lat = self._obj.latitude
                lon = self._obj.longitude
                return (float(lon.mean()), float(lat.mean()))

            def plot(self):
                # plot this array's data on a map
                pass
                ...

    Then, in an ipython session::

        >>> import databricks.koalas as ks
        >>> from my_ext_lib import GeoAccessor  # doctest: +SKIP
        >>> type(GeoAccessor)  # doctest: +SKIP
        <class 'type'>
        >>> kdf = ks.DataFrame({"longitude": np.linspace(0,10),
        ...                     "latitude": np.linspace(0, 20)})
        >>> kdf.geo.center  # doctest: +SKIP
        (5.0, 10.0)

        >>> kdf.geo.plot()  # doctest: +SKIP
        ...
    """
    # Imported here rather than at module level; presumably to avoid a circular import with
    # the package's __init__ -- confirm before hoisting.
    from databricks.koalas import DataFrame

    return _register_accessor(name, DataFrame)
def register_series_accessor(name):
    """
    Register a custom accessor with a Series object

    Parameters
    ----------
    name : str
        name used when calling the accessor after it is registered

    Returns
    -------
    callable
        A class decorator.

    See Also
    --------
    register_dataframe_accessor: Register a custom accessor on DataFrame objects
    register_index_accessor: Register a custom accessor on Index objects

    Notes
    -----
    When accessed, your accessor will be initialized with the Koalas object the user is interacting
    with. The code signature must be::

        def __init__(self, koalas_obj):
            # constructor logic
        ...

    In the pandas API, if data passed to your accessor has an incorrect dtype, it's recommended to
    raise an ``AttributeError`` for consistency purposes. In Koalas, ``ValueError`` is more
    frequently used to annotate when a value's datatype is unexpected for a given method/function.

    Ultimately, you can structure this however you like, but Koalas would likely do something like
    this:

    >>> ks.Series(['a', 'b']).dt
    ...
    Traceback (most recent call last):
        ...
    ValueError: Cannot call DatetimeMethods on type StringType

    Examples
    --------
    In your library code::

        import databricks.koalas as ks

        @ks.extensions.register_series_accessor("geo")
        class GeoAccessor:

            def __init__(self, koalas_obj):
                self._obj = koalas_obj

            @property
            def is_valid(self):
                # boolean check to see if series contains valid geometry
                return True if my_validation_logic(self._obj) else False
                ...

    Then, in an ipython session::

        >>> import databricks.koalas as ks
        >>> from my_ext_lib import GeoAccessor  # doctest: +SKIP
        >>> type(GeoAccessor)  # doctest: +SKIP
        <class 'type'>
        >>> kdf = ks.DataFrame({"longitude": np.linspace(0,10),
        ...                     "latitude": np.linspace(0, 20)})
        >>> kdf.longitude.geo.is_valid  # doctest: +SKIP
        True
        ...
    """
    # Imported here rather than at module level; presumably to avoid a circular import with
    # the package's __init__ -- confirm before hoisting.
    from databricks.koalas import Series

    return _register_accessor(name, Series)
def register_index_accessor(name):
    """
    Register a custom accessor with an Index

    Parameters
    ----------
    name : str
        name used when calling the accessor after it is registered

    Returns
    -------
    callable
        A class decorator.

    See Also
    --------
    register_dataframe_accessor: Register a custom accessor on DataFrame objects
    register_series_accessor: Register a custom accessor on Series objects

    Notes
    -----
    When accessed, your accessor will be initialized with the Koalas object the user is interacting
    with. The code signature must be::

        def __init__(self, koalas_obj):
            # constructor logic
        ...

    In the pandas API, if data passed to your accessor has an incorrect dtype, it's recommended to
    raise an ``AttributeError`` for consistency purposes. In Koalas, ``ValueError`` is more
    frequently used to annotate when a value's datatype is unexpected for a given method/function.

    Ultimately, you can structure this however you like, but Koalas would likely do something like
    this:

    >>> ks.Series(['a', 'b']).dt
    ...
    Traceback (most recent call last):
        ...
    ValueError: Cannot call DatetimeMethods on type StringType

    Examples
    --------
    In your library code::

        import databricks.koalas as ks

        @ks.extensions.register_index_accessor("foo")
        class CustomAccessor:

            def __init__(self, koalas_obj):
                self._obj = koalas_obj
                self.item = "baz"

            @property
            def bar(self):
                # return item value
                return self.item
                ...

    Then, in an ipython session::

        >>> import databricks.koalas as ks
        >>> from my_ext_lib import CustomAccessor  # doctest: +SKIP
        >>> type(CustomAccessor)  # doctest: +SKIP
        <class 'type'>
        >>> kdf = ks.DataFrame({"longitude": np.linspace(0,10),
        ...                     "latitude": np.linspace(0, 20)})
        >>> kdf.index.foo.bar  # doctest: +SKIP
        'baz'
        ...
    """
    # Imported here rather than at module level; presumably to avoid a circular import with
    # the package's __init__ -- confirm before hoisting.
    from databricks.koalas import Index

    return _register_accessor(name, Index)
@contextmanager
def assert_produces_warning(
    expected_warning=Warning,
    filter_level="always",
    check_stacklevel=True,
    raise_on_extra_warnings=True,
):
    """
    Context manager for running code expected to either raise a specific
    warning, or not raise any warnings. Verifies that the code raises the
    expected warning, and that it does not raise any other unexpected
    warnings. It is basically a wrapper around ``warnings.catch_warnings``.

    Notes
    -----
    Replicated from pandas._testing.

    Parameters
    ----------
    expected_warning : {Warning, False, None}, default Warning
        The type of Exception raised. ``exception.Warning`` is the base
        class for all warnings. To check that no warning is returned,
        specify ``False`` or ``None``.
    filter_level : str or None, default "always"
        Specifies whether warnings are ignored, displayed, or turned
        into errors.
        Valid values are:
        * "error" - turns matching warnings into exceptions
        * "ignore" - discard the warning
        * "always" - always emit a warning
        * "default" - print the warning the first time it is generated
          from each location
        * "module" - print the warning the first time it is generated
          from each module
        * "once" - print the warning the first time it is generated
    check_stacklevel : bool, default True
        If True, displays the line that called the function containing
        the warning to show where the function is called. Otherwise, the
        line that implements the function is displayed.
    raise_on_extra_warnings : bool, default True
        Whether extra warnings not of the type `expected_warning` should
        cause the test to fail.

    Yields
    ------
    list
        The list of recorded warnings produced by ``warnings.catch_warnings(record=True)``.

    Raises
    ------
    AssertionError
        If the expected warning was not seen, if a ``FutureWarning``/``DeprecationWarning`` was
        attributed to a file other than the caller's (with ``check_stacklevel``), or if
        unexpected warnings were recorded (with ``raise_on_extra_warnings``).

    Examples
    --------
    >>> import warnings
    >>> with assert_produces_warning():
    ...     warnings.warn(UserWarning())
    ...
    >>> with assert_produces_warning(False):  # doctest: +SKIP
    ...     warnings.warn(RuntimeWarning())
    ...
    Traceback (most recent call last):
        ...
    AssertionError: Caused unexpected warning(s): ['RuntimeWarning'].
    >>> with assert_produces_warning(UserWarning):  # doctest: +SKIP
    ...     warnings.warn(RuntimeWarning())
    Traceback (most recent call last):
        ...
    AssertionError: Did not see expected warning of class 'UserWarning'

    .. warning:: This is *not* thread-safe.
    """
    __tracebackhide__ = True

    with warnings.catch_warnings(record=True) as w:

        saw_warning = False
        warnings.simplefilter(filter_level)
        yield w
        extra_warnings = []

        for actual_warning in w:
            if expected_warning and issubclass(actual_warning.category, expected_warning):
                saw_warning = True

                if check_stacklevel and issubclass(
                    actual_warning.category, (FutureWarning, DeprecationWarning)
                ):
                    from inspect import getframeinfo, stack

                    # Frame 0 is this generator, frame 1 is contextlib's ``__exit__`` resuming
                    # it, so frame 2 is the code containing the ``with`` statement.
                    caller = getframeinfo(stack()[2][0])
                    if actual_warning.filename != caller.filename:
                        # A single string (adjacent literals), not a tuple of fragments.
                        msg = (
                            "Warning not set with correct stacklevel. "
                            "File where warning is raised: {} != ".format(actual_warning.filename)
                            + "{}. Warning message: {}".format(
                                caller.filename, actual_warning.message
                            )
                        )
                        # Raised explicitly so the check is not stripped under ``python -O``.
                        raise AssertionError(msg)
            else:
                extra_warnings.append(
                    (
                        actual_warning.category.__name__,
                        actual_warning.message,
                        actual_warning.filename,
                        actual_warning.lineno,
                    )
                )
        if expected_warning and not saw_warning:
            raise AssertionError(
                "Did not see expected warning of class {}".format(repr(expected_warning.__name__))
            )
        if raise_on_extra_warnings and extra_warnings:
            raise AssertionError("Caused unexpected warning(s): {}".format(repr(extra_warnings)))
@contextlib.contextmanager
def ensure_removed(obj, attr):
    """
    Ensure attribute attached to 'obj' during testing is removed in the end

    Parameters
    ----------
    obj : object
        Object (typically a class) the attribute is attached to inside the ``with`` body.
    attr : str
        Name of the attribute to delete on exit. It is not an error if it was never set.
    """
    try:
        yield
    finally:
        # Best-effort cleanup: the attribute may never have been set.
        with contextlib.suppress(AttributeError):
            delattr(obj, attr)


class CustomAccessor:
    """Minimal accessor used by the tests: keeps the wrapped object and a fixed ``item``."""

    def __init__(self, obj):
        self.obj = obj
        self.item = "item"

    @property
    def prop(self):
        """Return ``item`` via a property."""
        return self.item

    def method(self):
        """Return ``item`` via a method call."""
        return self.item

    def check_length(self, col=None):
        """
        Return ``len(self.obj[col])`` when a column is given or the wrapped object is a Koalas
        DataFrame, otherwise ``len(self.obj)``.

        Raises
        ------
        ValueError
            If ``len(self.obj)`` fails; the original exception is chained as the cause.
        """
        if col is not None or isinstance(self.obj, ks.DataFrame):
            return len(self.obj[col])
        try:
            return len(self.obj)
        except Exception as e:
            raise ValueError(str(e)) from e
class ExtensionTest(ReusedSQLTestCase):
    """Tests for registering custom accessors on Koalas DataFrame, Series and Index."""

    @property
    def pdf(self):
        # A fresh frame (with a fresh random index) is built on every access.
        return pd.DataFrame(
            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 0]},
            index=np.random.rand(9),
        )

    @property
    def kdf(self):
        return ks.from_pandas(self.pdf)

    @property
    def accessor(self):
        return CustomAccessor(self.kdf)

    def test_setup(self):
        self.assertEqual("item", self.accessor.item)

    def test_dataframe_register(self):
        with ensure_removed(ks.DataFrame, "test"):
            register_dataframe_accessor("test")(CustomAccessor)
            kdf = self.kdf
            self.assertEqual(kdf.test.prop, "item")
            self.assertEqual(kdf.test.method(), "item")
            self.assertEqual(len(kdf["a"]), kdf.test.check_length("a"))

    def test_series_register(self):
        with ensure_removed(ks.Series, "test"):
            register_series_accessor("test")(CustomAccessor)
            kdf = self.kdf
            self.assertEqual(kdf.a.test.prop, "item")
            self.assertEqual(kdf.a.test.method(), "item")
            self.assertEqual(kdf.a.test.check_length(), len(kdf["a"]))

    def test_index_register(self):
        with ensure_removed(ks.Index, "test"):
            register_index_accessor("test")(CustomAccessor)
            kdf = self.kdf
            self.assertEqual(kdf.index.test.prop, "item")
            self.assertEqual(kdf.index.test.method(), "item")
            self.assertEqual(kdf.index.test.check_length(), kdf.index.size)

    def test_accessor_works(self):
        # Clean up afterwards so the accessor does not leak onto ks.Series for other tests.
        with ensure_removed(ks.Series, "test"):
            register_series_accessor("test")(CustomAccessor)

            s = ks.Series([1, 2])
            self.assertIs(s.test.obj, s)
            self.assertEqual(s.test.prop, "item")
            self.assertEqual(s.test.method(), "item")

    def test_overwrite_warns(self):
        mean = ks.Series.mean
        try:
            with assert_produces_warning(UserWarning) as w:
                register_series_accessor("mean")(CustomAccessor)
                s = ks.Series([1, 2])
                self.assertEqual(s.mean.prop, "item")
            msg = str(w[0].message)
            self.assertIn("mean", msg)
            self.assertIn("CustomAccessor", msg)
            self.assertIn("Series", msg)
        finally:
            # Restore the real method that the accessor registration replaced.
            ks.Series.mean = mean

    def test_raises_attr_error(self):
        with ensure_removed(ks.Series, "bad"):

            # Register the accessor so the AttributeError really comes from constructing it,
            # not merely from ``bad`` being an unknown attribute.
            @register_series_accessor("bad")
            class Bad:
                def __init__(self, data):
                    raise AttributeError("whoops")

            with self.assertRaises(AttributeError):
                ks.Series([1, 2], dtype=object).bad
autosummary:: + :toctree: api/ + + register_dataframe_accessor + register_series_accessor + register_index_accessor \ No newline at end of file diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst index 34a6f00313..7258a3dd4a 100644 --- a/docs/source/reference/index.rst +++ b/docs/source/reference/index.rst @@ -12,3 +12,4 @@ API Reference window groupby ml + extensions