From 1fa360ed388fd1a4e734923e5966ee1c6f0a9899 Mon Sep 17 00:00:00 2001 From: iamsudip Date: Sun, 1 Sep 2019 04:34:29 +0530 Subject: [PATCH 1/3] enforce required arguments, fixes #72 --- thriftpy2/thrift.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/thriftpy2/thrift.py b/thriftpy2/thrift.py index 1b2e3cff..a614dbb1 100644 --- a/thriftpy2/thrift.py +++ b/thriftpy2/thrift.py @@ -191,10 +191,22 @@ def __getattr__(self, _api): def __dir__(self): return self._service.thrift_services + def _validate_required_args(self, _api, api_args): + thrift_spec = getattr(self._service, _api + "_args").thrift_spec + for item in thrift_spec.items(): + arg = item[1][1] + required = item[1][2] + if required and arg not in api_args: + raise TApplicationException( + TApplicationException.UNKNOWN_METHOD, + '{arg} is required argument for {service}.{api}'.format( + arg=arg, service=self._service, api=_api)) + def _req(self, _api, *args, **kwargs): _kw = args2kwargs(getattr(self._service, _api + "_args").thrift_spec, *args) kwargs.update(_kw) + self._validate_required_args(_api, kwargs) result_cls = getattr(self._service, _api + "_result") self._send(_api, **kwargs) From 5392c804c4fb3a0bf83510ab98cbcfd39fbdaf07 Mon Sep 17 00:00:00 2001 From: iamsudip Date: Sun, 1 Sep 2019 05:14:49 +0530 Subject: [PATCH 2/3] faster but uglier implementation --- thriftpy2/contrib/aio/client.py | 11 +++++++--- thriftpy2/thrift.py | 38 ++++++++++++++++++--------------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/thriftpy2/contrib/aio/client.py b/thriftpy2/contrib/aio/client.py index 23b5972e..48544722 100644 --- a/thriftpy2/contrib/aio/client.py +++ b/thriftpy2/contrib/aio/client.py @@ -26,9 +26,14 @@ def __dir__(self): @asyncio.coroutine def _req(self, _api, *args, **kwargs): - _kw = args2kwargs(getattr(self._service, _api + "_args").thrift_spec, - *args) - kwargs.update(_kw) + try: + kwargs = args2kwargs(getattr(self._service, _api + "_args").thrift_spec, + *args, **kwargs) + except ValueError as e: + raise TApplicationException( + TApplicationException.UNKNOWN_METHOD, + 'missing required argument {arg} for {service}.{api}'.format( + arg=e.args[0], service=self._service, api=_api)) result_cls = getattr(self._service, _api + "_result") yield from self._send(_api, **kwargs) diff --git a/thriftpy2/thrift.py b/thriftpy2/thrift.py index a614dbb1..f97aff2d 100644 --- a/thriftpy2/thrift.py +++ b/thriftpy2/thrift.py @@ -13,12 +13,22 @@ import linecache import types -from ._compat import with_metaclass +from ._compat import with_metaclass, PY3 +if PY3: + from itertools import zip_longest +else: + from itertools import izip_longest as zip_longest -def args2kwargs(thrift_spec, *args): - arg_names = [item[1][1] for item in sorted(thrift_spec.items())] - return dict(zip(arg_names, args)) +def args2kwargs(thrift_spec, *args, **kwargs): + for item, value in zip_longest(sorted(thrift_spec.items()), args): + arg_name = item[1][1] + required = item[1][2] + if value is not None: + kwargs[item[1][1]] = value + if required and arg_name not in kwargs: + raise ValueError(arg_name) + return kwargs def parse_spec(ttype, spec=None): @@ -191,22 +201,16 @@ def __getattr__(self, _api): def __dir__(self): return self._service.thrift_services - def _validate_required_args(self, _api, api_args): - thrift_spec = getattr(self._service, _api + "_args").thrift_spec - for item in thrift_spec.items(): - arg = item[1][1] - required = item[1][2] - if required and arg not in api_args: - raise TApplicationException( + def _req(self, _api, *args, **kwargs): + try: + kwargs = args2kwargs(getattr(self._service, _api + "_args").thrift_spec, + *args, **kwargs) + except ValueError as e: + raise TApplicationException( TApplicationException.UNKNOWN_METHOD, '{arg} is required argument for {service}.{api}'.format( - arg=arg, service=self._service, api=_api)) + arg=e.args[0], service=self._service, api=_api)) - def _req(self, _api, *args, **kwargs): - _kw = args2kwargs(getattr(self._service, _api + "_args").thrift_spec, - *args) - kwargs.update(_kw) - self._validate_required_args(_api, kwargs) result_cls = getattr(self._service, _api + "_result") self._send(_api, **kwargs) From 74b93d0bcbd7c1ea63ac8b0ad254d866257c2119 Mon Sep 17 00:00:00 2001 From: iamsudip Date: Sun, 1 Sep 2019 13:48:43 +0530 Subject: [PATCH 3/3] add tests and make function name pythonic --- tests/addressbook.thrift | 2 +- tests/test_aio.py | 11 +++++++++++ tests/test_http.py | 9 +++++++++ tests/test_rpc.py | 8 ++++++++ thriftpy2/contrib/aio/client.py | 6 +++--- thriftpy2/thrift.py | 8 ++++---- 6 files changed, 36 insertions(+), 8 deletions(-) diff --git a/tests/addressbook.thrift b/tests/addressbook.thrift index bf36a171..96957951 100644 --- a/tests/addressbook.thrift +++ b/tests/addressbook.thrift @@ -34,7 +34,7 @@ exception PersonNotExistsError { service AddressBookService { void ping(); - string hello(1: string name); + string hello(1: required string name); bool add(1: Person person); bool remove(1: string name) throws (1: PersonNotExistsError not_exists); Person get(1: string name) throws (1: PersonNotExistsError not_exists); diff --git a/tests/test_aio.py b/tests/test_aio.py index d6f3fbd4..8e0f8443 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -16,6 +16,7 @@ from thriftpy2.rpc import make_aio_server, make_aio_client # noqa from thriftpy2.transport import TTransportException # noqa +from thriftpy2.thrift import TApplicationException # noqa addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__), "addressbook.thrift")) @@ -164,6 +165,16 @@ async def test_string_api(aio_server): c.close() +@pytest.mark.asyncio +async def test_required_argument(aio_server): + c = await client() + assert await c.hello("") == "hello " + + with pytest.raises(TApplicationException): + await c.hello() + c.close() + + @pytest.mark.asyncio async def test_string_api_with_ssl(aio_ssl_server): c = await client() diff --git a/tests/test_http.py b/tests/test_http.py index c3adfe10..df5ed5ac 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -13,6 +13,7 @@ thriftpy2.install_import_hook() # noqa from thriftpy2.http import make_server, client_context +from thriftpy2.thrift import TApplicationException addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__), @@ -115,6 +116,14 @@ def test_string_api(server): assert c.hello("world") == "hello world" +def test_required_argument(server): + with client() as c: + with pytest.raises(TApplicationException): + c.hello() + + assert c.hello(name="") == "hello " + + def test_huge_res(server): with client() as c: big_str = "world" * 100000 diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 8b679d10..ca697aef 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -17,6 +17,7 @@ from thriftpy2._compat import PY3 # noqa from thriftpy2.rpc import make_server, client_context # noqa from thriftpy2.transport import TTransportException # noqa +from thriftpy2.thrift import TApplicationException # noqa addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__), @@ -158,6 +159,13 @@ def test_string_api(server): assert c.hello("world") == "hello world" +def test_required_argument(server): + with client() as c: + assert c.hello("") == "hello " + with pytest.raises(TApplicationException): + c.hello() + + def test_string_api_with_ssl(ssl_server): with ssl_client() as c: assert c.hello("world") == "hello world" diff --git a/thriftpy2/contrib/aio/client.py b/thriftpy2/contrib/aio/client.py index 48544722..9daa983f 100644 --- a/thriftpy2/contrib/aio/client.py +++ b/thriftpy2/contrib/aio/client.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import asyncio import functools -from thriftpy2.thrift import args2kwargs +from thriftpy2.thrift import args_to_kwargs from thriftpy2.thrift import TApplicationException, TMessageType @@ -27,13 +27,13 @@ def __dir__(self): @asyncio.coroutine def _req(self, _api, *args, **kwargs): try: - kwargs = args2kwargs(getattr(self._service, _api + "_args").thrift_spec, + kwargs = args_to_kwargs(getattr(self._service, _api + "_args").thrift_spec, *args, **kwargs) except ValueError as e: raise TApplicationException( TApplicationException.UNKNOWN_METHOD, 'missing required argument {arg} for {service}.{api}'.format( - arg=e.args[0], service=self._service, api=_api)) + arg=e.args[0], service=self._service.__name__, api=_api)) result_cls = getattr(self._service, _api + "_result") yield from self._send(_api, **kwargs) diff --git a/thriftpy2/thrift.py b/thriftpy2/thrift.py index f97aff2d..381595a1 100644 --- a/thriftpy2/thrift.py +++ b/thriftpy2/thrift.py @@ -20,10 +20,10 @@ from itertools import izip_longest as zip_longest -def args2kwargs(thrift_spec, *args, **kwargs): +def args_to_kwargs(thrift_spec, *args, **kwargs): for item, value in zip_longest(sorted(thrift_spec.items()), args): arg_name = item[1][1] - required = item[1][2] + required = item[1][-1] if value is not None: kwargs[item[1][1]] = value if required and arg_name not in kwargs: @@ -203,13 +203,13 @@ def __dir__(self): def _req(self, _api, *args, **kwargs): try: - kwargs = args2kwargs(getattr(self._service, _api + "_args").thrift_spec, + kwargs = args_to_kwargs(getattr(self._service, _api + "_args").thrift_spec, *args, **kwargs) except ValueError as e: raise TApplicationException( TApplicationException.UNKNOWN_METHOD, '{arg} is required argument for {service}.{api}'.format( - arg=e.args[0], service=self._service, api=_api)) + arg=e.args[0], service=self._service.__name__, api=_api)) result_cls = getattr(self._service, _api + "_result")