diff --git a/requirements.txt b/requirements.txt
index 214fc54ce6c0..dd8b28a65e39 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -34,7 +34,7 @@ flask-talisman==0.7.0      # via apache-superset (setup.py)
 flask-wtf==0.14.2          # via apache-superset (setup.py), flask-appbuilder
 flask==1.1.1               # via apache-superset (setup.py), flask-appbuilder, flask-babel, flask-caching, flask-compress, flask-jwt-extended, flask-login, flask-migrate, flask-openid, flask-sqlalchemy, flask-wtf
 geographiclib==1.50        # via geopy
-geopy==1.20.0              # via apache-superset (setup.py)
+geopy==1.21.0              # via apache-superset (setup.py)
 gunicorn==20.0.4           # via apache-superset (setup.py)
 humanize==0.5.1            # via apache-superset (setup.py)
 importlib-metadata==1.4.0  # via jsonschema, kombu
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index d7438aaf8bfa..5bb3fe6b4a36 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -265,15 +265,23 @@ class ChartDataSelectOptionsSchema(ChartDataPostProcessingOperationOptionsSchema
     columns = fields.List(
         fields.String(),
         description="Columns which to select from the input data, in the desired "
-        "order. If columns are renamed, the old column name should be "
+        "order. If columns are renamed, the original column name should be "
        "referenced here.",
         example=["country", "gender", "age"],
+        required=False,
+    )
+    exclude = fields.List(
+        fields.String(),
+        description="Columns to exclude from selection.",
+        example=["my_temp_column"],
+        required=False,
     )
     rename = fields.List(
         fields.Dict(),
         description="columns which to rename, mapping source column to target column. "
         "For instance, `{'y': 'y2'}` will rename the column `y` to `y2`.",
         example=[{"age": "average_age"}],
+        required=False,
     )

@@ -335,12 +343,81 @@ class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema)
     aggregates = ChartDataAggregateConfigField()


+class ChartDataGeohashDecodeOptionsSchema(
+    ChartDataPostProcessingOperationOptionsSchema
+):
+    """
+    Geohash decode operation config.
+    """
+
+    geohash = fields.String(
+        description="Name of source column containing geohash string", required=True,
+    )
+    latitude = fields.String(
+        description="Name of target column for decoded latitude", required=True,
+    )
+    longitude = fields.String(
+        description="Name of target column for decoded longitude", required=True,
+    )
+
+
+class ChartDataGeohashEncodeOptionsSchema(
+    ChartDataPostProcessingOperationOptionsSchema
+):
+    """
+    Geohash encode operation config.
+    """
+
+    latitude = fields.String(
+        description="Name of source latitude column", required=True,
+    )
+    longitude = fields.String(
+        description="Name of source longitude column", required=True,
+    )
+    geohash = fields.String(
+        description="Name of target column for encoded geohash string", required=True,
+    )
+
+
+class ChartDataGeodeticParseOptionsSchema(
+    ChartDataPostProcessingOperationOptionsSchema
+):
+    """
+    Geodetic point string parsing operation config.
+    """
+
+    geodetic = fields.String(
+        description="Name of source column containing geodetic point strings",
+        required=True,
+    )
+    latitude = fields.String(
+        description="Name of target column for decoded latitude", required=True,
+    )
+    longitude = fields.String(
+        description="Name of target column for decoded longitude", required=True,
+    )
+    altitude = fields.String(
+        description="Name of target column for decoded altitude. If omitted, "
+        "altitude information in the geodetic string is ignored.",
+        required=False,
+    )
+
+
 class ChartDataPostProcessingOperationSchema(Schema):
     operation = fields.String(
         description="Post processing operation type",
         required=True,
         validate=validate.OneOf(
-            choices=("aggregate", "pivot", "rolling", "select", "sort")
+            choices=(
+                "aggregate",
+                "geodetic_parse",
+                "geohash_decode",
+                "geohash_encode",
+                "pivot",
+                "rolling",
+                "select",
+                "sort",
+            )
         ),
         example="aggregate",
     )
@@ -638,4 +715,7 @@ class ChartDataResponseSchema(Schema):
     ChartDataRollingOptionsSchema,
     ChartDataSelectOptionsSchema,
     ChartDataSortOptionsSchema,
+    ChartDataGeohashDecodeOptionsSchema,
+    ChartDataGeohashEncodeOptionsSchema,
+    ChartDataGeodeticParseOptionsSchema,
 )
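For context, a `post_processing` block validated by these schemas might look as follows. This is an illustrative sketch only: the column names are placeholders, and it assumes the usual `operation`/`options` envelope of `ChartDataPostProcessingOperationSchema` (the `options` field is not shown in this diff):

```python
# Hypothetical post_processing entries exercising the new operations.
# Column names ("location", "geohash", ...) are made up for illustration.
post_processing = [
    {
        "operation": "geohash_decode",
        "options": {
            "geohash": "geohash",      # source column with geohash strings
            "latitude": "latitude",    # target column for decoded latitude
            "longitude": "longitude",  # target column for decoded longitude
        },
    },
    {
        "operation": "geodetic_parse",
        "options": {
            "geodetic": "location",
            "latitude": "latitude",
            "longitude": "longitude",
            # "altitude" is optional; when omitted, altitude information
            # in the geodetic string is ignored.
        },
    },
]
```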
diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py
index f2a688c252ee..dabebed785e3 100644
--- a/superset/utils/pandas_postprocessing.py
+++ b/superset/utils/pandas_postprocessing.py
@@ -15,10 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

+import geohash as geohash_lib
 import numpy as np
 from flask_babel import gettext as _
+from geopy.point import Point
 from pandas import DataFrame, NamedAgg

 from superset.exceptions import QueryObjectValidationError
@@ -144,10 +146,7 @@ def _append_columns(
     :return: new DataFrame with combined data from `base_df` and `append_df`
     """
     return base_df.assign(
-        **{
-            target: append_df[append_df.columns[idx]]
-            for idx, target in enumerate(columns.values())
-        }
+        **{target: append_df[source] for source, target in columns.items()}
     )


@@ -323,9 +322,12 @@ def rolling(  # pylint: disable=too-many-arguments
     return df


-@validate_column_args("columns", "rename")
+@validate_column_args("columns", "exclude", "rename")
 def select(
-    df: DataFrame, columns: List[str], rename: Optional[Dict[str, str]] = None
+    df: DataFrame,
+    columns: Optional[List[str]] = None,
+    exclude: Optional[List[str]] = None,
+    rename: Optional[Dict[str, str]] = None,
 ) -> DataFrame:
     """
     Only select a subset of columns in the original dataset. Can be useful for
@@ -333,15 +335,21 @@ def select(

     :param df: DataFrame on which the selection will be based.
     :param columns: Columns which to select from the DataFrame, in the desired order.
-                    If columns are renamed, the old column name should be referenced
-                    here.
+                    If left undefined, all columns will be selected. If columns are
+                    renamed, the original column name should be referenced here.
+    :param exclude: columns to exclude from selection. As exclusion is performed
+                    prior to renaming, the original column name should be referenced here.
     :param rename: columns which to rename, mapping source column to target column.
                    For instance, `{'y': 'y2'}` will rename the column `y` to `y2`.
     :return: Subset of columns in original DataFrame
     :raises QueryObjectValidationError: If the request is incorrect
     """
-    df_select = df[columns]
+    df_select = df.copy(deep=False)
+    if columns:
+        df_select = df_select[columns]
+    if exclude:
+        df_select = df_select.drop(exclude, axis=1)
     if rename is not None:
         df_select = df_select.rename(columns=rename)
     return df_select
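A quick sketch of how the reworked `select` behaves with the new `exclude` argument, run against a toy DataFrame rather than a real query result (column names are illustrative):

```python
from pandas import DataFrame

from superset.utils import pandas_postprocessing as proc

df = DataFrame({"label": ["x", "y"], "y": [1.0, 2.0], "tmp": [0, 1]})

# keep all columns except `tmp`, renaming `y` to `y1`
post_df = proc.select(df=df, exclude=["tmp"], rename={"y": "y1"})
print(post_df.columns.tolist())  # ['label', 'y1']

# `columns` still selects (and orders) by the original, pre-rename names
post_df = proc.select(df=df, columns=["y", "label"], rename={"y": "y1"})
print(post_df.columns.tolist())  # ['y1', 'label']
```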
@@ -350,6 +358,7 @@ def select(
 @validate_column_args("columns")
 def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame:
     """
+    Calculate row-by-row difference for select columns.

     :param df: DataFrame on which the diff will be based.
     :param columns: columns on which to perform diff, mapping source column to
@@ -369,6 +378,7 @@
 @validate_column_args("columns")
 def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
     """
+    Calculate cumulative sum/product/min/max for select columns.

     :param df: DataFrame on which the cumulative operation will be based.
     :param columns: columns on which to perform a cumulative operation, mapping source
@@ -377,7 +387,7 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
         `y2` based on cumulative values calculated from `y`, leaving the original
         column `y` unchanged.
     :param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
-    :return:
+    :return: DataFrame with cumulated columns
     """
     df_cum = df[columns.keys()]
     operation = "cum" + operator
@@ -388,3 +398,92 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
         _("Invalid cumulative operator: %(operator)s", operator=operator)
     )
     return _append_columns(df, getattr(df_cum, operation)(), columns)
+
+
+def geohash_decode(
+    df: DataFrame, geohash: str, longitude: str, latitude: str
+) -> DataFrame:
+    """
+    Decode a geohash column into longitude and latitude
+
+    :param df: DataFrame containing geohash data
+    :param geohash: Name of source column containing geohash location.
+    :param longitude: Name of new column to be created containing longitude.
+    :param latitude: Name of new column to be created containing latitude.
+    :return: DataFrame with decoded longitudes and latitudes
+    """
+    try:
+        lonlat_df = DataFrame()
+        lonlat_df["latitude"], lonlat_df["longitude"] = zip(
+            *df[geohash].apply(geohash_lib.decode)
+        )
+        return _append_columns(
+            df, lonlat_df, {"latitude": latitude, "longitude": longitude}
+        )
+    except ValueError:
+        raise QueryObjectValidationError(_("Invalid geohash string"))
+
+
+def geohash_encode(
+    df: DataFrame, geohash: str, longitude: str, latitude: str,
+) -> DataFrame:
+    """
+    Encode longitude and latitude into geohash
+
+    :param df: DataFrame containing longitude and latitude data
+    :param geohash: Name of new column to be created containing geohash location.
+    :param longitude: Name of source column containing longitude.
+    :param latitude: Name of source column containing latitude.
+    :return: DataFrame with encoded geohashes
+    """
+    try:
+        encode_df = df[[latitude, longitude]].copy()
+        encode_df.columns = ["latitude", "longitude"]
+        encode_df["geohash"] = encode_df.apply(
+            lambda row: geohash_lib.encode(row["latitude"], row["longitude"]), axis=1,
+        )
+        return _append_columns(df, encode_df, {"geohash": geohash})
+    except ValueError:
+        raise QueryObjectValidationError(_("Invalid longitude/latitude"))
+
+
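To illustrate the two geohash helpers, here is a round trip on a single coordinate pair. The coordinates and expected geohash are taken from the test fixture added below; the column names are arbitrary:

```python
from pandas import DataFrame

from superset.utils import pandas_postprocessing as proc

df = DataFrame({"lat": [40.71277496], "lon": [-74.00597306]})

encoded = proc.geohash_encode(df, geohash="geohash", latitude="lat", longitude="lon")
print(encoded["geohash"][0])  # dr5regw3pg6f

decoded = proc.geohash_decode(
    encoded, geohash="geohash", latitude="lat2", longitude="lon2"
)
# lat2/lon2 approximate the original coordinates; geohash encoding is
# lossy at fixed precision, hence the rounding to six decimals in the
# tests below.
```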
+def geodetic_parse(
+    df: DataFrame,
+    geodetic: str,
+    longitude: str,
+    latitude: str,
+    altitude: Optional[str] = None,
+) -> DataFrame:
+    """
+    Parse a column containing geodetic point strings as defined by
+    [geopy.point.Point](https://geopy.readthedocs.io/en/stable/#geopy.point.Point).
+
+    :param df: DataFrame containing geodetic point data
+    :param geodetic: Name of source column containing geodetic point string.
+    :param longitude: Name of new column to be created containing longitude.
+    :param latitude: Name of new column to be created containing latitude.
+    :param altitude: Name of new column to be created containing altitude.
+    :return: DataFrame with decoded latitudes, longitudes and altitudes
+    """
+
+    def _parse_location(location: str) -> Tuple[float, float, float]:
+        """
+        Parse a string containing a geodetic point and return latitude, longitude
+        and altitude
+        """
+        point = Point(location)  # type: ignore
+        return point[0], point[1], point[2]
+
+    try:
+        geodetic_df = DataFrame()
+        (
+            geodetic_df["latitude"],
+            geodetic_df["longitude"],
+            geodetic_df["altitude"],
+        ) = zip(*df[geodetic].apply(_parse_location))
+        columns = {"latitude": latitude, "longitude": longitude}
+        if altitude:
+            columns["altitude"] = altitude
+        return _append_columns(df, geodetic_df, columns)
+    except ValueError:
+        raise QueryObjectValidationError(_("Invalid geodetic string"))
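And a sketch of `geodetic_parse`, using one of the geodetic strings from the fixture added below. Note that geopy's `Point` normalizes altitude to kilometers, so `5.5km` parses to `5.5` (column names are again illustrative):

```python
from pandas import DataFrame

from superset.utils import pandas_postprocessing as proc

df = DataFrame({"location": ["40.71277496, -74.00597306, 5.5km"]})

parsed = proc.geodetic_parse(
    df,
    geodetic="location",
    latitude="lat",
    longitude="lon",
    altitude="alt",  # optional; omit to discard altitude information
)
print(parsed[["lat", "lon", "alt"]].iloc[0].tolist())
# [40.71277496, -74.00597306, 5.5]
```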
diff --git a/tests/fixtures/dataframes.py b/tests/fixtures/dataframes.py
index e565dc40a002..dd01085a18a4 100644
--- a/tests/fixtures/dataframes.py
+++ b/tests/fixtures/dataframes.py
@@ -119,3 +119,17 @@
     index=to_datetime(["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]),
     data={"label": ["x", "y", "z", "q"], "y": [1.0, 2.0, 3.0, 4.0]},
 )
+
+lonlat_df = DataFrame(
+    {
+        "city": ["New York City", "Sydney"],
+        "geohash": ["dr5regw3pg6f", "r3gx2u9qdevk"],
+        "latitude": [40.71277496, -33.85598011],
+        "longitude": [-74.00597306, 151.20666526],
+        "altitude": [5.5, 0.012],
+        "geodetic": [
+            "40.71277496, -74.00597306, 5.5km",
+            "-33.85598011, 151.20666526, 12m",
+        ],
+    }
+)
diff --git a/tests/pandas_postprocessing_tests.py b/tests/pandas_postprocessing_tests.py
index a981477d0bb4..14342cc895dc 100644
--- a/tests/pandas_postprocessing_tests.py
+++ b/tests/pandas_postprocessing_tests.py
@@ -16,7 +16,7 @@
 # under the License.
 # isort:skip_file
 import math
-from typing import Any, List
+from typing import Any, List, Optional

 from pandas import Series

@@ -24,7 +24,7 @@
 from superset.utils import pandas_postprocessing as proc

 from .base_tests import SupersetTestCase
-from .fixtures.dataframes import categories_df, timeseries_df
+from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df


 def series_to_list(series: Series) -> List[Any]:
@@ -43,6 +43,19 @@ def series_to_list(series: Series) -> List[Any]:
     ]


+def round_floats(
+    floats: List[Optional[float]], precision: int
+) -> List[Optional[float]]:
+    """
+    Round list of floats to certain precision
+
+    :param floats: floats to round
+    :param precision: intended decimal precision
+    :return: rounded floats
+    """
+    return [round(val, precision) if val is not None else None for val in floats]
+
+
 class PostProcessingTestCase(SupersetTestCase):
     def test_pivot(self):
         aggregates = {"idx_nulls": {"operator": "sum"}}
@@ -219,25 +232,40 @@ def test_select(self):
         post_df = proc.select(df=timeseries_df, columns=["label"])
         self.assertListEqual(post_df.columns.tolist(), ["label"])

-        # rename one column
+        # rename and select one column
         post_df = proc.select(df=timeseries_df, columns=["y"], rename={"y": "y1"})
         self.assertListEqual(post_df.columns.tolist(), ["y1"])

         # rename one and leave one unchanged
-        post_df = proc.select(
-            df=timeseries_df, columns=["label", "y"], rename={"y": "y1"}
-        )
+        post_df = proc.select(df=timeseries_df, rename={"y": "y1"})
         self.assertListEqual(post_df.columns.tolist(), ["label", "y1"])

+        # drop one column
+        post_df = proc.select(df=timeseries_df, exclude=["label"])
+        self.assertListEqual(post_df.columns.tolist(), ["y"])
+
+        # rename and drop one column
+        post_df = proc.select(df=timeseries_df, rename={"y": "y1"}, exclude=["label"])
+        self.assertListEqual(post_df.columns.tolist(), ["y1"])
+
         # invalid columns
         self.assertRaises(
             QueryObjectValidationError,
             proc.select,
             df=timeseries_df,
-            columns=["qwerty"],
+            columns=["abc"],
             rename={"abc": "qwerty"},
         )

+        # select renamed column by new name
+        self.assertRaises(
+            QueryObjectValidationError,
+            proc.select,
+            df=timeseries_df,
+            columns=["label_new"],
+            rename={"label": "label_new"},
+        )
+
     def test_diff(self):
         # overwrite column
         post_df = proc.diff(df=timeseries_df, columns={"y": "y"})
@@ -288,3 +316,83 @@ def test_cum(self):
             columns={"y": "y"},
             operator="abc",
         )
+
+    def test_geohash_decode(self):
+        # decode lon/lat from geohash
+        post_df = proc.geohash_decode(
+            df=lonlat_df[["city", "geohash"]],
+            geohash="geohash",
+            latitude="latitude",
+            longitude="longitude",
+        )
+        self.assertListEqual(
+            sorted(post_df.columns.tolist()),
+            sorted(["city", "geohash", "latitude", "longitude"]),
+        )
+        self.assertListEqual(
+            round_floats(series_to_list(post_df["longitude"]), 6),
+            round_floats(series_to_list(lonlat_df["longitude"]), 6),
+        )
+        self.assertListEqual(
+            round_floats(series_to_list(post_df["latitude"]), 6),
+            round_floats(series_to_list(lonlat_df["latitude"]), 6),
+        )
+
+    def test_geohash_encode(self):
+        # encode lon/lat into geohash
+        post_df = proc.geohash_encode(
+            df=lonlat_df[["city", "latitude", "longitude"]],
+            latitude="latitude",
+            longitude="longitude",
+            geohash="geohash",
+        )
+        self.assertListEqual(
+            sorted(post_df.columns.tolist()),
+            sorted(["city", "geohash", "latitude", "longitude"]),
+        )
+        self.assertListEqual(
+            series_to_list(post_df["geohash"]), series_to_list(lonlat_df["geohash"]),
+        )
+
+    def test_geodetic_parse(self):
+        # parse geodetic string with altitude into lon/lat/altitude
+        post_df = proc.geodetic_parse(
+            df=lonlat_df[["city", "geodetic"]],
+            geodetic="geodetic",
+            latitude="latitude",
+            longitude="longitude",
+            altitude="altitude",
+        )
+        self.assertListEqual(
+            sorted(post_df.columns.tolist()),
+            sorted(["city", "geodetic", "latitude", "longitude", "altitude"]),
+        )
+        self.assertListEqual(
+            series_to_list(post_df["longitude"]),
+            series_to_list(lonlat_df["longitude"]),
+        )
+        self.assertListEqual(
+            series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
+        )
+        self.assertListEqual(
+            series_to_list(post_df["altitude"]), series_to_list(lonlat_df["altitude"]),
+        )
+
+        # parse geodetic string into lon/lat
+        post_df = proc.geodetic_parse(
+            df=lonlat_df[["city", "geodetic"]],
+            geodetic="geodetic",
+            latitude="latitude",
+            longitude="longitude",
+        )
+        self.assertListEqual(
+            sorted(post_df.columns.tolist()),
+            sorted(["city", "geodetic", "latitude", "longitude"]),
+        )
+        self.assertListEqual(
+            series_to_list(post_df["longitude"]),
+            series_to_list(lonlat_df["longitude"]),
+        )
+        self.assertListEqual(
+            series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
+        )
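One detail worth noting in the `round_floats` test helper: the None check must be an identity test (`val is not None`), since `0.0` is falsy but a perfectly valid coordinate that should be rounded, not dropped. A quick sanity check of that behaviour, outside the test class:

```python
# round_floats keeps None placeholders while rounding real values;
# 0.0 is a valid float (e.g. a point on the equator) and must survive.
assert round_floats([1.23456789, None, 0.0], 6) == [1.234568, None, 0.0]
```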