From a1df581908b7a0165dd232151647a23b5d6b4800 Mon Sep 17 00:00:00 2001 From: Luke Inman-Semerau Date: Thu, 12 Dec 2013 11:43:54 -0800 Subject: [PATCH] adding back old method of doing http requests for non keep-alive connections (I tried setting Connection:'close' that *should* work, but ghostdriver wasn't honoring it and didn't respond with the right header) Fixes Issue #6707 #6706 --- .../webdriver/remote/remote_connection.py | 175 +++++++++++++++--- 1 file changed, 145 insertions(+), 30 deletions(-) diff --git a/py/selenium/webdriver/remote/remote_connection.py b/py/selenium/webdriver/remote/remote_connection.py index 3cf67f52e9398..de12508f909a5 100644 --- a/py/selenium/webdriver/remote/remote_connection.py +++ b/py/selenium/webdriver/remote/remote_connection.py @@ -1,6 +1,4 @@ -# Copyright 2008-2009 WebDriver committers -# Copyright 2008-2009 Google Inc. -# Copyright 2013 BrowserStack +# Copyright 2008-2013 Software Freedom Conservancy # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,17 +19,11 @@ try: import http.client as httplib -except ImportError: - import httplib as httplib - -try: from urllib import request as url_request -except ImportError: - import urllib2 as url_request - -try: from urllib import parse -except ImportError: +except ImportError: # above is available in py3+, below is py2.7 + import httplib as httplib + import urllib2 as url_request import urlparse as parse from .command import Command @@ -41,6 +33,98 @@ LOGGER = logging.getLogger(__name__) +class Request(url_request.Request): + """ + Extends the url_request.Request to support all HTTP request types. + """ + + def __init__(self, url, data=None, method=None): + """ + Initialise a new HTTP request. + + :Args: + - url - String for the URL to send the request to. + - data - Data to send with the request. + """ + if method is None: + method = data is not None and 'POST' or 'GET' + elif method != 'POST' and method != 'PUT': + data = None + self._method = method + url_request.Request.__init__(self, url, data=data) + + def get_method(self): + """ + Returns the HTTP method used by this request. + """ + return self._method + + +class Response(object): + """ + Represents an HTTP response. + """ + + def __init__(self, fp, code, headers, url): + """ + Initialise a new Response. + + :Args: + - fp - The response body file object. + - code - The HTTP status code returned by the server. + - headers - A dictionary of headers returned by the server. + - url - URL of the retrieved resource represented by this Response. + """ + self.fp = fp + self.read = fp.read + self.code = code + self.headers = headers + self.url = url + + def close(self): + """ + Close the response body file object. + """ + self.read = None + self.fp = None + + def info(self): + """ + Returns the response headers. + """ + return self.headers + + def geturl(self): + """ + Returns the URL for the resource returned in this response. + """ + return self.url + + +class HttpErrorHandler(url_request.HTTPDefaultErrorHandler): + """ + A custom HTTP error handler. + + Used to return Response objects instead of raising an HTTPError exception. + """ + + def http_error_default(self, req, fp, code, msg, headers): + """ + Default HTTP error handler. + + :Args: + - req - The original Request object. + - fp - The response body file object. + - code - The HTTP status code returned by the server. + - msg - The HTTP status message returned by the server. + - headers - The response headers. + + :Returns: + A new Response object. + """ + return Response(fp, code, headers, req.get_full_url()) + + class RemoteConnection(object): """ A connection with the Remote WebDriver server. @@ -71,7 +155,8 @@ def __init__(self, remote_server_addr, keep_alive=False): LOGGER.info('Could not get IP address for host: %s' % parsed_url.hostname) self._url = remote_server_addr - self._conn = httplib.HTTPConnection(str(addr), str(parsed_url.port)) + if keep_alive: + self._conn = httplib.HTTPConnection(str(addr), str(parsed_url.port)) self._commands = { Command.STATUS: ('GET', '/status'), Command.NEW_SESSION: ('POST', '/session'), @@ -278,24 +363,54 @@ def _request(self, url, data=None, method=None): LOGGER.debug('%s %s %s' % (method, url, data)) parsed_url = parse.urlparse(url) - headers = {method: parsed_url.path, - "User-Agent": "Python http auth", - "Content-type": "application/json;charset=\"UTF-8\"", - "Accept": "application/json"} + if self.keep_alive: - headers['Connection'] = 'keep-alive' - - # for basic auth - if parsed_url.username: - auth = base64.standard_b64encode('%s:%s' % (parsed_url.username, parsed_url.password)).replace('\n', '') - # Authorization header - headers["Authorization"] = "Basic %s" % auth - - self._conn.request(method, parsed_url.path, data, headers) - resp = self._conn.getresponse() - statuscode = resp.status - statusmessage = resp.msg - LOGGER.debug('%s %s' % (statuscode, statusmessage)) + headers = {"Connection": 'keep-alive', method: parsed_url.path, + "User-Agent": "Python http auth", + "Content-type": "application/json;charset=\"UTF-8\"", + "Accept": "application/json"} + if parsed_url.username: + auth = base64.standard_b64encode('%s:%s' % + (parsed_url.username, parsed_url.password)).replace('\n', '') + headers["Authorization"] = "Basic %s" % auth + self._conn.request(method, parsed_url.path, data, headers) + resp = self._conn.getresponse() + statuscode = resp.status + else: + password_manager = None + if parsed_url.username: + netloc = parsed_url.hostname + if parsed_url.port: + netloc += ":%s" % parsed_url.port + cleaned_url = parse.urlunparse((parsed_url.scheme, + netloc, + parsed_url.path, + parsed_url.params, + parsed_url.query, + parsed_url.fragment)) + password_manager = url_request.HTTPPasswordMgrWithDefaultRealm() + password_manager.add_password(None, + "%s://%s" % (parsed_url.scheme, netloc), + parsed_url.username, + parsed_url.password) + request = Request(cleaned_url, data=data.encode('utf-8'), method=method) + else: + request = Request(url, data=data.encode('utf-8'), method=method) + + request.add_header('Accept', 'application/json') + request.add_header('Content-Type', 'application/json;charset=UTF-8') + + if password_manager: + opener = url_request.build_opener(url_request.HTTPRedirectHandler(), + HttpErrorHandler(), + url_request.HTTPBasicAuthHandler(password_manager)) + else: + opener = url_request.build_opener(url_request.HTTPRedirectHandler(), + HttpErrorHandler()) + resp = opener.open(request) + statuscode = resp.code + resp.getheader = lambda x: resp.headers.getheader(x) + data = resp.read() try: if 399 < statuscode < 500: