Skip to content

Commit

Permalink
adding back old method of doing http requests for non keep-alive conn…
Browse files Browse the repository at this point in the history
…ections

(I tried setting Connection:'close' that *should* work, but ghostdriver wasn't honoring it and didn't respond with the right header)

Fixes Issue #6707 #6706
  • Loading branch information
lukeis committed Dec 12, 2013
1 parent 79f5c19 commit a1df581
Showing 1 changed file with 145 additions and 30 deletions.
175 changes: 145 additions & 30 deletions py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# Copyright 2008-2009 WebDriver committers
# Copyright 2008-2009 Google Inc.
# Copyright 2013 BrowserStack
# Copyright 2008-2013 Software Freedom Conservancy
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -21,17 +19,11 @@

try:
import http.client as httplib
except ImportError:
import httplib as httplib

try:
from urllib import request as url_request
except ImportError:
import urllib2 as url_request

try:
from urllib import parse
except ImportError:
except ImportError: # above is available in py3+, below is py2.7
import httplib as httplib
import urllib2 as url_request
import urlparse as parse

from .command import Command
Expand All @@ -41,6 +33,98 @@
LOGGER = logging.getLogger(__name__)


class Request(url_request.Request):
"""
Extends the url_request.Request to support all HTTP request types.
"""

def __init__(self, url, data=None, method=None):
"""
Initialise a new HTTP request.
:Args:
- url - String for the URL to send the request to.
- data - Data to send with the request.
"""
if method is None:
method = data is not None and 'POST' or 'GET'
elif method != 'POST' and method != 'PUT':
data = None
self._method = method
url_request.Request.__init__(self, url, data=data)

def get_method(self):
"""
Returns the HTTP method used by this request.
"""
return self._method


class Response(object):
"""
Represents an HTTP response.
"""

def __init__(self, fp, code, headers, url):
"""
Initialise a new Response.
:Args:
- fp - The response body file object.
- code - The HTTP status code returned by the server.
- headers - A dictionary of headers returned by the server.
- url - URL of the retrieved resource represented by this Response.
"""
self.fp = fp
self.read = fp.read
self.code = code
self.headers = headers
self.url = url

def close(self):
"""
Close the response body file object.
"""
self.read = None
self.fp = None

def info(self):
"""
Returns the response headers.
"""
return self.headers

def geturl(self):
"""
Returns the URL for the resource returned in this response.
"""
return self.url


class HttpErrorHandler(url_request.HTTPDefaultErrorHandler):
"""
A custom HTTP error handler.
Used to return Response objects instead of raising an HTTPError exception.
"""

def http_error_default(self, req, fp, code, msg, headers):
"""
Default HTTP error handler.
:Args:
- req - The original Request object.
- fp - The response body file object.
- code - The HTTP status code returned by the server.
- msg - The HTTP status message returned by the server.
- headers - The response headers.
:Returns:
A new Response object.
"""
return Response(fp, code, headers, req.get_full_url())


class RemoteConnection(object):
"""
A connection with the Remote WebDriver server.
Expand Down Expand Up @@ -71,7 +155,8 @@ def __init__(self, remote_server_addr, keep_alive=False):
LOGGER.info('Could not get IP address for host: %s' % parsed_url.hostname)

self._url = remote_server_addr
self._conn = httplib.HTTPConnection(str(addr), str(parsed_url.port))
if keep_alive:
self._conn = httplib.HTTPConnection(str(addr), str(parsed_url.port))
self._commands = {
Command.STATUS: ('GET', '/status'),
Command.NEW_SESSION: ('POST', '/session'),
Expand Down Expand Up @@ -278,24 +363,54 @@ def _request(self, url, data=None, method=None):
LOGGER.debug('%s %s %s' % (method, url, data))

parsed_url = parse.urlparse(url)
headers = {method: parsed_url.path,
"User-Agent": "Python http auth",
"Content-type": "application/json;charset=\"UTF-8\"",
"Accept": "application/json"}

if self.keep_alive:
headers['Connection'] = 'keep-alive'

# for basic auth
if parsed_url.username:
auth = base64.standard_b64encode('%s:%s' % (parsed_url.username, parsed_url.password)).replace('\n', '')
# Authorization header
headers["Authorization"] = "Basic %s" % auth

self._conn.request(method, parsed_url.path, data, headers)
resp = self._conn.getresponse()
statuscode = resp.status
statusmessage = resp.msg
LOGGER.debug('%s %s' % (statuscode, statusmessage))
headers = {"Connection": 'keep-alive', method: parsed_url.path,
"User-Agent": "Python http auth",
"Content-type": "application/json;charset=\"UTF-8\"",
"Accept": "application/json"}
if parsed_url.username:
auth = base64.standard_b64encode('%s:%s' %
(parsed_url.username, parsed_url.password)).replace('\n', '')
headers["Authorization"] = "Basic %s" % auth
self._conn.request(method, parsed_url.path, data, headers)
resp = self._conn.getresponse()
statuscode = resp.status
else:
password_manager = None
if parsed_url.username:
netloc = parsed_url.hostname
if parsed_url.port:
netloc += ":%s" % parsed_url.port
cleaned_url = parse.urlunparse((parsed_url.scheme,
netloc,
parsed_url.path,
parsed_url.params,
parsed_url.query,
parsed_url.fragment))
password_manager = url_request.HTTPPasswordMgrWithDefaultRealm()
password_manager.add_password(None,
"%s://%s" % (parsed_url.scheme, netloc),
parsed_url.username,
parsed_url.password)
request = Request(cleaned_url, data=data.encode('utf-8'), method=method)
else:
request = Request(url, data=data.encode('utf-8'), method=method)

request.add_header('Accept', 'application/json')
request.add_header('Content-Type', 'application/json;charset=UTF-8')

if password_manager:
opener = url_request.build_opener(url_request.HTTPRedirectHandler(),
HttpErrorHandler(),
url_request.HTTPBasicAuthHandler(password_manager))
else:
opener = url_request.build_opener(url_request.HTTPRedirectHandler(),
HttpErrorHandler())
resp = opener.open(request)
statuscode = resp.code
resp.getheader = lambda x: resp.headers.getheader(x)

data = resp.read()
try:
if 399 < statuscode < 500:
Expand Down

0 comments on commit a1df581

Please sign in to comment.