Skip to content

Commit

Permalink
updates from PR review
Browse files Browse the repository at this point in the history
  • Loading branch information
ermeaney committed Nov 17, 2020
1 parent fb790f6 commit 25620e0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
6 changes: 3 additions & 3 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def client_with_custom_header_factory(timeout=3000):
return make_client(addressbook.AddressBookService,
url="http://127.0.0.1:6080",
timeout=timeout,
http_header_factory=CustomHeaderFactory)
http_header_factory=CustomHeaderFactory())


def client_without_url(timeout=3000):
Expand All @@ -174,7 +174,7 @@ def test_client_context_with_header_factory(server):


def test_client_context_custom_with_header_factory(server):
with client_context_with_header_factory() as c:
with client_context_with_custom_header_factory() as c:
assert c.hello("world") == "hello world"


Expand All @@ -185,7 +185,7 @@ def test_client_with_header_factory(server):


def test_client_with_custom_header_factory(server):
c = client_with_header_factory()
c = client_with_custom_header_factory()
assert c.hello("world") == "hello world"
c.close()

Expand Down
12 changes: 5 additions & 7 deletions thriftpy2/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,7 @@ def __init__(self, uri, timeout=None, ssl_context_factory=None, http_header_fact
self.path += '?%s' % parsed.query
self.__wbuf = BytesIO()
self.__http = None
if http_header_factory:
self._http_header_factory = http_header_factory
else:
self._http_header_factory = THttpHeaderFactory()
self._http_header_factory = http_header_factory or THttpHeaderFactory()
self.__timeout = None
if timeout:
self.setTimeout(timeout)
Expand Down Expand Up @@ -254,16 +251,17 @@ def flush(self):
self.__http.putheader('Host', self.host)
self.__http.putheader('Content-Type', 'application/x-thrift')
self.__http.putheader('Content-Length', str(len(data)))
if (not self._http_header_factory.get_headers() or
'User-Agent' not in self._http_header_factory.get_headers()):
custom_headers = self._http_header_factory.get_headers()
if (not custom_headers or
'User-Agent' not in custom_headers):
user_agent = 'Python/THttpClient'
script = os.path.basename(sys.argv[0])
if script:
user_agent = '%s (%s)' % (
user_agent, urllib.parse.quote(script))
self.__http.putheader('User-Agent', user_agent)

if self._http_header_factory.get_headers():
if custom_headers:
for key, val in self._http_header_factory.get_headers().items():
self.__http.putheader(key, val)

Expand Down

0 comments on commit 25620e0

Please sign in to comment.