Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enforce required arguments, fixes #72 #81

Merged
merged 3 commits into from
Sep 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/addressbook.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ exception PersonNotExistsError {

service AddressBookService {
void ping();
string hello(1: string name);
string hello(1: required string name);
bool add(1: Person person);
bool remove(1: string name) throws (1: PersonNotExistsError not_exists);
Person get(1: string name) throws (1: PersonNotExistsError not_exists);
Expand Down
11 changes: 11 additions & 0 deletions tests/test_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from thriftpy2.rpc import make_aio_server, make_aio_client # noqa
from thriftpy2.transport import TTransportException # noqa
from thriftpy2.thrift import TApplicationException # noqa

addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__),
"addressbook.thrift"))
Expand Down Expand Up @@ -164,6 +165,16 @@ async def test_string_api(aio_server):
c.close()


@pytest.mark.asyncio
async def test_required_argument(aio_server):
c = await client()
assert await c.hello("") == "hello "

with pytest.raises(TApplicationException):
await c.hello()
c.close()


@pytest.mark.asyncio
async def test_string_api_with_ssl(aio_ssl_server):
c = await client()
Expand Down
9 changes: 9 additions & 0 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
thriftpy2.install_import_hook() # noqa

from thriftpy2.http import make_server, client_context
from thriftpy2.thrift import TApplicationException


addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__),
Expand Down Expand Up @@ -115,6 +116,14 @@ def test_string_api(server):
assert c.hello("world") == "hello world"


def test_required_argument(server):
with client() as c:
with pytest.raises(TApplicationException):
c.hello()

assert c.hello(name="") == "hello "


def test_huge_res(server):
with client() as c:
big_str = "world" * 100000
Expand Down
8 changes: 8 additions & 0 deletions tests/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from thriftpy2._compat import PY3 # noqa
from thriftpy2.rpc import make_server, client_context # noqa
from thriftpy2.transport import TTransportException # noqa
from thriftpy2.thrift import TApplicationException # noqa


addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__),
Expand Down Expand Up @@ -158,6 +159,13 @@ def test_string_api(server):
assert c.hello("world") == "hello world"


def test_required_argument(server):
with client() as c:
assert c.hello("") == "hello "
with pytest.raises(TApplicationException):
c.hello()


def test_string_api_with_ssl(ssl_server):
with ssl_client() as c:
assert c.hello("world") == "hello world"
Expand Down
13 changes: 9 additions & 4 deletions thriftpy2/contrib/aio/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import asyncio
import functools
from thriftpy2.thrift import args2kwargs
from thriftpy2.thrift import args_to_kwargs
from thriftpy2.thrift import TApplicationException, TMessageType


Expand All @@ -26,9 +26,14 @@ def __dir__(self):

@asyncio.coroutine
def _req(self, _api, *args, **kwargs):
_kw = args2kwargs(getattr(self._service, _api + "_args").thrift_spec,
*args)
kwargs.update(_kw)
try:
kwargs = args_to_kwargs(getattr(self._service, _api + "_args").thrift_spec,
*args, **kwargs)
except ValueError as e:
raise TApplicationException(
TApplicationException.UNKNOWN_METHOD,
'missing required argument {arg} for {service}.{api}'.format(
arg=e.args[0], service=self._service.__name__, api=_api))
result_cls = getattr(self._service, _api + "_result")

yield from self._send(_api, **kwargs)
Expand Down
30 changes: 23 additions & 7 deletions thriftpy2/thrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,22 @@
import linecache
import types

from ._compat import with_metaclass
from ._compat import with_metaclass, PY3
if PY3:
from itertools import zip_longest
else:
from itertools import izip_longest as zip_longest


def args2kwargs(thrift_spec, *args):
arg_names = [item[1][1] for item in sorted(thrift_spec.items())]
return dict(zip(arg_names, args))
def args_to_kwargs(thrift_spec, *args, **kwargs):
for item, value in zip_longest(sorted(thrift_spec.items()), args):
arg_name = item[1][1]
required = item[1][-1]
if value is not None:
kwargs[item[1][1]] = value
if required and arg_name not in kwargs:
raise ValueError(arg_name)
return kwargs


def parse_spec(ttype, spec=None):
Expand Down Expand Up @@ -192,9 +202,15 @@ def __dir__(self):
return self._service.thrift_services

def _req(self, _api, *args, **kwargs):
_kw = args2kwargs(getattr(self._service, _api + "_args").thrift_spec,
*args)
kwargs.update(_kw)
try:
kwargs = args_to_kwargs(getattr(self._service, _api + "_args").thrift_spec,
*args, **kwargs)
except ValueError as e:
raise TApplicationException(
TApplicationException.UNKNOWN_METHOD,
'{arg} is required argument for {service}.{api}'.format(
arg=e.args[0], service=self._service.__name__, api=_api))

result_cls = getattr(self._service, _api + "_result")

self._send(_api, **kwargs)
Expand Down