Skip to content

Commit

Permalink
fix accept decorator bug (#90)
Browse files Browse the repository at this point in the history
* fix accept decorator bug

* add tests for previous exception cases

* remove duplicate code

---------

Co-authored-by: Keegan Cordeiro <[email protected]>
  • Loading branch information
corke2013 and Keegan Cordeiro authored Mar 27, 2024
1 parent 036d02d commit c4e26a8
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 78 deletions.
96 changes: 48 additions & 48 deletions predicthq/endpoints/decorators.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,73 @@
import functools
from collections import defaultdict

from pydantic import ValidationError as PydanticValidationError

from predicthq.endpoints.schemas import ResultSet
from predicthq.exceptions import ValidationError


def _to_url_params(data, glue=".", separator=","):
def _kwargs_to_key_list_mapping(kwargs, separator="__"):
"""
Converts data dictionary to url parameters
Converts kwargs to a nested dictionary mapping keys to lists of values
"""
params = {}
for key, value in data.items():
if isinstance(value, bool):
params[key] = 1 if value else 0
elif isinstance(value, list):
params[key] = separator.join(map(str, value))
elif isinstance(value, dict):
params.update(_flatten_dict(value, glue, separator, parent_key=key))
else:
params[key] = value
return params


def _flatten_dict(d, glue, separator, parent_key=""):
flat_dict = {}
for k, v in d.items():
if isinstance(v, dict):
flat_dict.update(_flatten_dict(v, glue, separator, f"{parent_key}{glue}{k}" if parent_key else k))
continue
if isinstance(v, list):
flat_dict.update({f"{parent_key}{glue}{k}" if parent_key else k: separator.join(map(str, v))})
continue
flat_dict.update({f"{parent_key}{glue}{k}" if parent_key else k: v})
return flat_dict

data = {}
for key, value in kwargs.items():
keys = key.split(separator, 1)
if len(keys) > 1:
value = {keys[1]: value}
if isinstance(value, dict):
value = _kwargs_to_key_list_mapping(value)

data[keys[0]] = [] if not data.get(keys[0]) else data[keys[0]]
data[keys[0]].append(value)
return data

def _assign_nested_key(parent_dict, keys, value):
current_key = keys[0]
if len(keys) > 1:
if current_key not in parent_dict:
parent_dict[current_key] = dict()
_assign_nested_key(parent_dict[current_key], keys[1:], value)
else:
parent_dict[current_key] = value

def _to_url_params(key_list_mapping, glue=".", separator=",", parent_key=""):
"""
Converts key_list_mapping to url parameters
"""
params = {}
for key, value in key_list_mapping.items():
current_key = f"{parent_key}{glue}{key}" if parent_key else key
for v in value:
if isinstance(v, dict):
params.update(_to_url_params(v, glue, separator, current_key))
elif isinstance(v, list):
params.update({current_key: separator.join(map(str, v))})
elif isinstance(v, bool):
params.update({current_key: 1 if v else 0})
else:
params.update({current_key: v})
return params

def _process_kwargs(kwargs, separator="__"):
data = dict()
for key, value in kwargs.items():
if separator in key:
_assign_nested_key(data, key.split(separator), value)
else:
data[key] = value
return data

def _to_json(key_list_mapping):
"""
Converts key_list_mapping to json
"""
json = {}
for key, value in key_list_mapping.items():
for v in value:
json[key] = dict() if not json.get(key) else json[key]
if isinstance(v, dict):
json[key].update(_to_json(v))
else:
json[key] = v
return json


def accepts(query_string=True, role=None):
def decorator(f):
@functools.wraps(f)
def wrapper(endpoint, *args, **kwargs):
data = _process_kwargs(kwargs)
key_list_mapping = _kwargs_to_key_list_mapping(kwargs)
if hasattr(endpoint, "mutate_bool_to_default_for_type"):
endpoint.mutate_bool_to_default_for_type(data)
endpoint.mutate_bool_to_default_for_type(key_list_mapping)

if query_string:
data = _to_url_params(data=data)
data = _to_url_params(key_list_mapping)
else:
data = _to_json(key_list_mapping)

return f(endpoint, *args, **data)

Expand Down
6 changes: 3 additions & 3 deletions predicthq/endpoints/v1/features/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class FeaturesEndpoint(UserBaseEndpoint):

BASE_FEATURE_CRITERIA = {"stats": ["sum", "count"], "phq_rank": None}
BASE_FEATURE_CRITERIA = {"stats": [["sum", "count"]], "phq_rank": [None]}
FIELDS_TO_MUTATE = frozenset([
"phq_attendance_",
"phq_viewership_sports",
Expand All @@ -18,8 +18,8 @@ class FeaturesEndpoint(UserBaseEndpoint):
@classmethod
def mutate_bool_to_default_for_type(cls, user_request_spec):
for key, val in user_request_spec.items():
if any(key.startswith(x) for x in cls.FIELDS_TO_MUTATE) and isinstance(val, bool):
user_request_spec[key] = cls.BASE_FEATURE_CRITERIA
if any(key.startswith(x) for x in cls.FIELDS_TO_MUTATE):
user_request_spec[key] = [cls.BASE_FEATURE_CRITERIA if isinstance(v, bool) else v for v in val]

@accepts(query_string=False)
@returns(FeatureResultSet)
Expand Down
2 changes: 1 addition & 1 deletion predicthq/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.2.0"
__version__ = "3.3.0"
107 changes: 81 additions & 26 deletions tests/endpoints/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,80 @@
from predicthq.exceptions import ValidationError


def test_to_params():
kwargs = {
"string_type": "my-string",
"list_type": [1, 2, 3],
"dict_type": {"key1": "val1", "key2": "val2"},
"bool_type": True,
"nested_dict_type": {"key1": {"key2": "val2", "key3": "val3"}},
}
expected = {
"string_type": "my-string",
"list_type": "1,2,3",
"dict_type.key1": u"val1",
"dict_type.key2": "val2",
"bool_type": 1,
"nested_dict_type.key1.key2": "val2",
"nested_dict_type.key1.key3": "val3",
}
@pytest.mark.parametrize("kwargs, expected", [
(
{
"normal_arg": ["value1"],
"nested": [{"arg": ["value2"]}, {"arg2": ["value3"]}],
"multiple": [{"level": [{"nested": ["value4"]}]}, {"level": [{"nested2": ["value5"]}]},
{"lev": [{"nested": ["value6"]}]}, {"lev2": [{"nested": ["value7"]}]}]
},
{
"normal_arg": "value1",
"nested.arg": "value2",
"nested.arg2": "value3",
"multiple.level.nested": "value4",
"multiple.level.nested2": "value5",
"multiple.lev.nested": "value6",
"multiple.lev2.nested": "value7",
}
),
(
{
"string_type": ["my-string"],
"list_type": [[1, 2, 3]],
"dict_type": [{"key1": ["val1"], "key2": ["val2"]}],
"bool_type": [True],
"nested_dict_type": [{"key1": [{"key2": ["val2"], "key3": ["val3"]}]}],
},
{
"string_type": "my-string",
"list_type": "1,2,3",
"dict_type.key1": u"val1",
"dict_type.key2": "val2",
"bool_type": 1,
"nested_dict_type.key1.key2": "val2",
"nested_dict_type.key1.key3": "val3",
}
)
])
def test_to_params(kwargs, expected):
assert decorators._to_url_params(kwargs) == expected


def test_kwargs_processor():
kwargs = {"normal_arg": "value", "nested__arg": "value", "multiple__level__nested": "value"}
expected = {"normal_arg": "value", "nested": {"arg": "value"}, "multiple": {"level": {"nested": "value"}}}
assert decorators._process_kwargs(kwargs) == expected
@pytest.mark.parametrize("kwargs, expected", [
(
{
"normal_arg": "value1",
"nested__arg": "value2",
"multiple__level__nested": "value3"
},
{
"normal_arg": ["value1"],
"nested": [{"arg": ["value2"]}],
"multiple": [{"level": [{"nested": ["value3"]}]}]
}
),
(
{
"normal_arg": "value1",
"nested__arg": "value2",
"nested__arg2": "value3",
"multiple__level__nested": "value4",
"multiple__level__nested2": "value5",
"multiple__lev__nested": "value6",
"multiple__lev2__nested": "value7"
},
{
"normal_arg": ["value1"],
"nested": [{"arg": ["value2"]}, {"arg2": ["value3"]}],
"multiple": [{"level": [{"nested": ["value4"]}]}, {"level": [{"nested2": ["value5"]}]},
{"lev": [{"nested": ["value6"]}]}, {"lev2": [{"nested": ["value7"]}]}]
}
)
])
def test_kwargs_processor(kwargs, expected):
assert decorators._kwargs_to_key_list_mapping(kwargs) == expected


def test_accepts():
Expand Down Expand Up @@ -66,7 +116,8 @@ def func(self, **kwargs):
return kwargs

endpoint = EndpointExample(None)
assert endpoint.func(arg1="test", arg2=[1, 2]).model_dump(exclude_none=True) == SchemaExample(**{"arg1": "test", "arg2": [1, 2]}).model_dump(exclude_none=True)
assert endpoint.func(arg1="test", arg2=[1, 2]).model_dump(exclude_none=True) == SchemaExample(
**{"arg1": "test", "arg2": [1, 2]}).model_dump(exclude_none=True)

with pytest.raises(ValidationError):
endpoint.func(arg2=[1, 2])
Expand All @@ -85,8 +136,10 @@ def func(self, **kwargs):
return kwargs

endpoint = EndpointExample(None)
assert endpoint.func(results=["item1", "item2"]).model_dump(exclude_none=True) == SchemaExample(**{"results": ["item1", "item2"]}).model_dump(exclude_none=True)
assert endpoint.func()._more(results=["item3", "item4"]).model_dump(exclude_none=True) == SchemaExample(**{"results": ["item3", "item4"]}).model_dump(exclude_none=True)
assert endpoint.func(results=["item1", "item2"]).model_dump(exclude_none=True) == SchemaExample(
**{"results": ["item1", "item2"]}).model_dump(exclude_none=True)
assert endpoint.func()._more(results=["item3", "item4"]).model_dump(exclude_none=True) == SchemaExample(
**{"results": ["item3", "item4"]}).model_dump(exclude_none=True)
assert endpoint == endpoint.func()._endpoint


Expand All @@ -104,7 +157,9 @@ def func(self, **kwargs):

endpoint = EndpointExample(None)
results = endpoint.func(results=[{"name": "item1"}, {"name": "item2"}])
assert results.model_dump(exclude_none=True) == SchemaExample(**{"results": [{"name": "item1"}, {"name": "item2"}]}).model_dump(exclude_none=True)
assert endpoint.func()._more(results=[{"name": "item2"}, {"name": "item4"}]).model_dump(exclude_none=True) == SchemaExample(
assert results.model_dump(exclude_none=True) == SchemaExample(
**{"results": [{"name": "item1"}, {"name": "item2"}]}).model_dump(exclude_none=True)
assert endpoint.func()._more(results=[{"name": "item2"}, {"name": "item4"}]).model_dump(
exclude_none=True) == SchemaExample(
**{"results": [{"name": "item2"}, {"name": "item4"}]}
).model_dump(exclude_none=True)

0 comments on commit c4e26a8

Please sign in to comment.