diff --git a/sh_scrapy/extension.py b/sh_scrapy/extension.py
index 00dc758..2f027da 100644
--- a/sh_scrapy/extension.py
+++ b/sh_scrapy/extension.py
@@ -1,5 +1,6 @@
 import logging
 from contextlib import suppress
+from warnings import warn
 from weakref import WeakKeyDictionary
 
 import scrapy
@@ -9,7 +10,6 @@
 from scrapy.exporters import PythonItemExporter
 from scrapy.http import Request
 from scrapy.utils.deprecate import create_deprecated_class
-from scrapy.utils.request import request_fingerprint
 
 from sh_scrapy import hsref
 from sh_scrapy.crawl import ignore_warnings
@@ -79,13 +79,41 @@ def spider_closed(self, spider, reason):
         """
 
 
-class HubstorageMiddleware(object):
+class HubstorageMiddleware:
 
-    def __init__(self):
+    @classmethod
+    def from_crawler(cls, crawler):
+        try:
+            result = cls(crawler)
+        except TypeError:
+            warn(
+                (
+                    "Subclasses of HubstorageMiddleware must now accept a "
+                    "crawler parameter in their __init__ method. This will "
+                    "become an error in the future."
+                ),
+                DeprecationWarning,
+            )
+            result = cls()
+            result._crawler = crawler
+            result._load_fingerprinter()
+        return result
+
+    def _load_fingerprinter(self):
+        if hasattr(self._crawler, "request_fingerprinter"):
+            self._fingerprint = lambda request: self._crawler.request_fingerprinter.fingerprint(request).hex()
+        else:
+            from scrapy.utils.request import request_fingerprint
+            self._fingerprint = request_fingerprint
+
+    def __init__(self, crawler=None):
         self._seen = WeakKeyDictionary()
         self.hsref = hsref.hsref
         self.pipe_writer = pipe_writer
         self.request_id_sequence = request_id_sequence
+        self._crawler = crawler
+        if crawler:
+            self._load_fingerprinter()
 
     def process_spider_input(self, response, spider):
         self.pipe_writer.write_request(
@@ -95,7 +123,7 @@ def process_spider_input(self, response, spider):
             rs=len(response.body),
             duration=response.meta.get('download_latency', 0) * 1000,
             parent=response.meta.get(HS_PARENT_ID_KEY),
-            fp=request_fingerprint(response.request),
+            fp=self._fingerprint(response.request),
         )
         self._seen[response] = next(self.request_id_sequence)
diff --git a/tests/test_extension.py b/tests/test_extension.py
index 067e3cc..71fab7d 100644
--- a/tests/test_extension.py
+++ b/tests/test_extension.py
@@ -3,11 +3,13 @@
 
 import mock
 import pytest
+import scrapy
+from packaging import version
+from pytest import warns
 from scrapy import Spider
 from scrapy.exporters import PythonItemExporter
 from scrapy.http import Request, Response
 from scrapy.item import Item
-from scrapy.utils.request import request_fingerprint
 from scrapy.utils.test import get_crawler
 
 from sh_scrapy.extension import HubstorageExtension, HubstorageMiddleware
@@ -92,7 +94,8 @@ def test_hs_ext_spider_closed(hs_ext):
 @pytest.fixture
 def hs_mware(monkeypatch):
     monkeypatch.setattr('sh_scrapy.extension.pipe_writer', mock.Mock())
-    return HubstorageMiddleware()
+    crawler = get_crawler()
+    return HubstorageMiddleware.from_crawler(crawler)
 
 
 def test_hs_mware_init(hs_mware):
@@ -106,9 +109,13 @@
     hs_mware.process_spider_input(response, Spider('test'))
     assert hs_mware.pipe_writer.write_request.call_count == 1
     args = hs_mware.pipe_writer.write_request.call_args[1]
+    if hasattr(hs_mware._crawler, "request_fingerprinter"):
+        fp = "1c735665b072000e11b0169081bce5bbaeac09a7"
+    else:
+        fp = "a001a1eb4537acdc8525edf1250065cab2657152"
     assert args == {
         'duration': 0,
-        'fp': request_fingerprint(response.request),
+        'fp': fp,
         'method': 'GET',
         'parent': None,
         'rs': 0,
@@ -138,3 +145,40 @@
     # make sure that we update hsparent meta only for requests
     assert result[0].meta.get(HS_PARENT_ID_KEY) is None
     assert result[1].meta[HS_PARENT_ID_KEY] == 'riq'
+
+
+@pytest.mark.skipif(
+    version.parse(scrapy.__version__) < version.parse("2.7"),
+    reason="Only Scrapy 2.7 and higher support centralized request fingerprints."
+)
+def test_custom_fingerprinter(monkeypatch):
+    monkeypatch.setattr('sh_scrapy.extension.pipe_writer', mock.Mock())
+
+    class CustomFingerprinter:
+        def fingerprint(self, request):
+            return b"foo"
+
+    crawler = get_crawler(settings_dict={"REQUEST_FINGERPRINTER_CLASS": CustomFingerprinter})
+    mw = HubstorageMiddleware.from_crawler(crawler)
+
+    response = Response('http://resp-url')
+    response.request = Request('http://req-url')
+    mw.process_spider_input(response, Spider('test'))
+    assert mw.pipe_writer.write_request.call_args[1]["fp"] == b"foo".hex()
+
+
+def test_subclassing():
+    class CustomHubstorageMiddleware(HubstorageMiddleware):
+        def __init__(self):
+            super().__init__()
+            self.foo = "bar"
+
+    crawler = get_crawler()
+    with warns(
+        DeprecationWarning,
+        match="must now accept a crawler parameter in their __init__ method",
+    ):
+        mw = CustomHubstorageMiddleware.from_crawler(crawler)
+
+    assert mw.foo == "bar"
+    assert hasattr(mw, "_fingerprint")
diff --git a/tox.ini b/tox.ini
index e6b5829..add9095 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,9 @@
 # tox.ini
 [tox]
 envlist = py36-scrapy16, py
+requires =
+    # https://github.com/pypa/virtualenv/issues/2550
+    virtualenv<=20.21.1
 
 [testenv]
 deps =