Skip to content

Commit

Permalink
Added typing in tests (#694)
Browse files Browse the repository at this point in the history
* reverted overloads for activate
* Fix matchers typing
* Added status and body attributes to BaseResponse
  • Loading branch information
beliaev-maksim authored Nov 14, 2023
1 parent c46cb06 commit 939562f
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 35 deletions.
2 changes: 1 addition & 1 deletion CHANGES
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
0.24.1
------

* Reintroduced overloads for better typing in `CallList`.
* Reverted overloads removal
* Added typing to `Call` attributes.


Expand Down
6 changes: 5 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[mypy]
exclude = tests
show_column_numbers=True
show_error_codes = True

Expand Down Expand Up @@ -27,3 +26,8 @@ warn_unreachable=False

strict_equality=True
ignore_missing_imports=True

[mypy-responses.tests.*]
disallow_untyped_calls=False
disallow_untyped_defs=False
disable_error_code = union-attr
25 changes: 23 additions & 2 deletions responses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,9 @@ def __init__(
self._calls: CallList = CallList()
self.passthrough = passthrough

self.status: int = 200
self.body: "_Body" = ""

def __eq__(self, other: Any) -> bool:
if not isinstance(other, BaseResponse):
return False
Expand Down Expand Up @@ -569,6 +572,8 @@ def __init__(
auto_calculate_content_length: bool = False,
**kwargs: Any,
) -> None:
super().__init__(method, url, **kwargs)

# if we were passed a `json` argument,
# override the body and content_type
if json is not None:
Expand Down Expand Up @@ -596,7 +601,6 @@ def __init__(
self.stream: Optional[bool] = stream
self.content_type: str = content_type # type: ignore[assignment]
self.auto_calculate_content_length: bool = auto_calculate_content_length
super().__init__(method, url, **kwargs)

def get_response(self, request: "PreparedRequest") -> HTTPResponse:
if self.body and isinstance(self.body, Exception):
Expand Down Expand Up @@ -641,6 +645,8 @@ def __init__(
content_type: Optional[str] = "text/plain",
**kwargs: Any,
) -> None:
super().__init__(method, url, **kwargs)

self.callback = callback

if stream is not None:
Expand All @@ -650,7 +656,6 @@ def __init__(
)
self.stream: Optional[bool] = stream
self.content_type: Optional[str] = content_type
super().__init__(method, url, **kwargs)

def get_response(self, request: "PreparedRequest") -> HTTPResponse:
headers = self.get_headers()
Expand Down Expand Up @@ -970,6 +975,22 @@ def __exit__(self, type: Any, value: Any, traceback: Any) -> bool:
self.reset()
return success

@overload
def activate(self, func: "_F" = ...) -> "_F":
"""Overload for scenario when 'responses.activate' is used."""

@overload
def activate( # type: ignore[misc]
self,
*,
registry: Type[Any] = ...,
assert_all_requests_are_fired: bool = ...,
) -> Callable[["_F"], "_F"]:
"""Overload for scenario when
'responses.activate(registry=, assert_all_requests_are_fired=True)' is used.
See https://github.com/getsentry/responses/pull/469 for more details
"""

def activate(
self,
func: Optional["_F"] = None,
Expand Down
4 changes: 2 additions & 2 deletions responses/_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def _dump(
"response": {
"method": rsp.method,
"url": rsp.url,
"body": rsp.body, # type: ignore[attr-defined]
"status": rsp.status, # type: ignore[attr-defined]
"body": rsp.body,
"status": rsp.status,
"headers": rsp.headers,
"content_type": rsp.content_type,
"auto_calculate_content_length": content_length,
Expand Down
3 changes: 2 additions & 1 deletion responses/matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Pattern
from typing import Tuple
from typing import Union
from urllib.parse import parse_qsl
Expand Down Expand Up @@ -391,7 +392,7 @@ def match(request: PreparedRequest) -> Tuple[bool, str]:


def header_matcher(
headers: Dict[str, str], strict_match: bool = False
headers: Dict[str, Union[str, Pattern[str]]], strict_match: bool = False
) -> Callable[..., Any]:
"""
Matcher to match 'headers' argument in request using the responses library.
Expand Down
8 changes: 5 additions & 3 deletions responses/tests/test_matchers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import gzip
import re
from typing import Any
from typing import List
from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -152,7 +154,7 @@ def test_json_params_matcher_json_list():


def test_json_params_matcher_json_list_empty():
json_a = []
json_a: "List[Any]" = []
json_b = "[]"
mock_request = Mock(body=json_b)
result = matchers.json_params_matcher(json_a)(mock_request)
Expand Down Expand Up @@ -457,7 +459,7 @@ def run():
(b"\xacHello World!", b"\xacHello World!"),
],
)
def test_multipart_matcher(req_file, match_file):
def test_multipart_matcher(req_file, match_file): # type: ignore[misc]
@responses.activate
def run():
req_data = {"some": "other", "data": "fields"}
Expand Down Expand Up @@ -796,7 +798,7 @@ def test_matchers_create_key_val_str():

class TestHeaderWithRegex:
@property
def url(self):
def url(self): # type: ignore[misc]
return "http://example.com/"

def _register(self):
Expand Down
2 changes: 1 addition & 1 deletion responses/tests/test_multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@pytest.mark.parametrize("execution_number", range(10))
def test_multithreading_lock(execution_number):
def test_multithreading_lock(execution_number): # type: ignore[misc]
"""Reruns test multiple times since error is random and
depends on CPU and can lead to false positive result.
Expand Down
12 changes: 6 additions & 6 deletions responses/tests/test_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def test_recorder_toml(self, httpserver):

def dump_to_file(file_path, registered):
with open(file_path, "wb") as file:
_dump(registered, file, tomli_w.dump)
_dump(registered, file, tomli_w.dump) # type: ignore[arg-type]

custom_recorder.dump_to_file = dump_to_file
custom_recorder.dump_to_file = dump_to_file # type: ignore[method-assign]

url202, url400, url404, url500 = self.prepare_server(httpserver)

Expand Down Expand Up @@ -151,25 +151,25 @@ def teardown_method(self):
assert not self.out_file.exists()

@pytest.mark.parametrize("parser", (yaml, tomli_w))
def test_add_from_file(self, parser):
def test_add_from_file(self, parser): # type: ignore[misc]
if parser == yaml:
with open(self.out_file, "w") as file:
parser.dump(get_data("example.com", "8080"), file)
else:
with open(self.out_file, "wb") as file:
with open(self.out_file, "wb") as file: # type: ignore[assignment]
parser.dump(get_data("example.com", "8080"), file)

@responses.activate
def run():
responses.patch("http://httpbin.org")
if parser == tomli_w:

def _parse_response_file(file_path):
def _parse_resp_f(file_path):
with open(file_path, "rb") as file:
data = _toml.load(file)
return data

responses.mock._parse_response_file = _parse_response_file
responses.mock._parse_response_file = _parse_resp_f # type: ignore[method-assign]

responses._add_from_file(file_path=self.out_file)
responses.post("http://httpbin.org/form")
Expand Down
36 changes: 18 additions & 18 deletions responses/tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def run():
),
],
)
def test_replace(original, replacement):
def test_replace(original, replacement): # type: ignore[misc]
@responses.activate
def run():
responses.add(responses.GET, "http://example.com/one", body="test1")
Expand Down Expand Up @@ -157,7 +157,7 @@ def run():
(re.compile(r"http://example\.com/one"), "http://example.com/one"),
],
)
def test_replace_error(original, replacement):
def test_replace_error(original, replacement): # type: ignore[misc]
@responses.activate
def run():
responses.add(responses.GET, original)
Expand Down Expand Up @@ -203,7 +203,7 @@ def run():
),
],
)
def test_upsert_replace(original, replacement):
def test_upsert_replace(original, replacement): # type: ignore[misc]
@responses.activate
def run():
responses.add(responses.GET, "http://example.com/one", body="test1")
Expand Down Expand Up @@ -241,7 +241,7 @@ def run():
),
],
)
def test_upsert_add(original, replacement):
def test_upsert_add(original, replacement): # type: ignore[misc]
@responses.activate
def run():
responses.add(responses.GET, "http://example.com/one", body="test1")
Expand Down Expand Up @@ -299,7 +299,7 @@ def run():
),
],
)
def test_response_equality(args1, kwargs1, args2, kwargs2, expected):
def test_response_equality(args1, kwargs1, args2, kwargs2, expected): # type: ignore[misc]
o1 = BaseResponse(*args1, **kwargs1)
o2 = BaseResponse(*args2, **kwargs2)
assert (o1 == o2) is expected
Expand Down Expand Up @@ -847,7 +847,7 @@ def send(
"adapter_class",
(CustomAdapter, PositionalArgsAdapter, PositionalArgsIncompleteAdapter),
)
def test_custom_adapter(self, adapter_class):
def test_custom_adapter(self, adapter_class): # type: ignore[misc]
"""Test basic adapter implementation and that responses can patch them properly."""

@responses.activate
Expand Down Expand Up @@ -902,12 +902,12 @@ def test_function(a, b=None):


@pytest.fixture
def my_fruit():
def my_fruit(): # type: ignore[misc]
return "apple"


@pytest.fixture
def fruit_basket(my_fruit):
def fruit_basket(my_fruit): # type: ignore[misc]
return ["banana", my_fruit]


Expand All @@ -926,7 +926,7 @@ def test_function(self, my_fruit, fruit_basket):

def test_activate_mock_interaction():
@patch("sys.stdout")
def test_function(mock_stdout):
def test_function(mock_stdout): # type: ignore[misc]
return mock_stdout

decorated_test_function = responses.activate(test_function)
Expand Down Expand Up @@ -1045,7 +1045,7 @@ def run():

@pytest.mark.parametrize("request_stream", (True, False, None))
@pytest.mark.parametrize("responses_stream", (True, False, None))
def test_response_cookies_session(request_stream, responses_stream):
def test_response_cookies_session(request_stream, responses_stream): # type: ignore[misc]
@responses.activate
def run():
url = "https://example.com/path"
Expand Down Expand Up @@ -1384,13 +1384,13 @@ def run():
# Type errors here and on 1250 are ignored because the stubs for requests
# are off https://github.com/python/typeshed/blob/f8501d33c737482a829c6db557a0be26895c5941
# /stubs/requests/requests/packages/__init__.pyi#L1
original_init = getattr(urllib3.HTTPResponse, "__init__") # type: ignore
original_init = getattr(urllib3.HTTPResponse, "__init__")

def patched_init(self, *args, **kwargs):
kwargs["enforce_content_length"] = True
original_init(self, *args, **kwargs)

monkeypatch.setattr(urllib3.HTTPResponse, "__init__", patched_init) # type: ignore
monkeypatch.setattr(urllib3.HTTPResponse, "__init__", patched_init)

run()
assert_reset()
Expand Down Expand Up @@ -1919,7 +1919,7 @@ def test_custom_target(monkeypatch):
"http://example.com/other/path/",
),
)
def test_request_param(url):
def test_request_param(url): # type: ignore[misc]
@responses.activate
def run():
params = {"hello": "world", "example": "params"}
Expand Down Expand Up @@ -1962,7 +1962,7 @@ def run():
@pytest.mark.parametrize(
"url", ("http://example.com", "http://example.com?hello=world")
)
def test_assert_call_count(url):
def test_assert_call_count(url): # type: ignore[misc]
@responses.activate
def run():
responses.add(responses.GET, url)
Expand Down Expand Up @@ -2160,7 +2160,7 @@ def run():
),
],
)
def test_response_representations(response_params, expected_representation):
def test_response_representations(response_params, expected_representation): # type: ignore[misc]
response = Response(**response_params)

assert str(response) == expected_representation
Expand Down Expand Up @@ -2206,7 +2206,7 @@ def run():
("http://fizzbuzz/foo", "http://fizzbuzz/foo"),
],
)
def test_rfc_compliance(url, other_url):
def test_rfc_compliance(url, other_url): # type: ignore[misc]
@responses.activate
def run():
responses.add(method=responses.GET, url=url)
Expand Down Expand Up @@ -2314,7 +2314,7 @@ def run_classic():
run_not_strict()

@pytest.mark.parametrize("assert_fired", (True, False, None))
def test_nested_decorators(self, assert_fired):
def test_nested_decorators(self, assert_fired): # type: ignore[misc]
"""Validate if assert_all_requests_are_fired is applied from the correct function.
assert_all_requests_are_fired must be applied from the function
Expand Down Expand Up @@ -2607,7 +2607,7 @@ def run():
assert_reset()

@pytest.mark.parametrize("raise_on_status", (True, False))
def test_max_retries_exceed(self, raise_on_status):
def test_max_retries_exceed(self, raise_on_status): # type: ignore[misc]
@responses.activate(registry=registries.OrderedRegistry)
def run():
url = "https://example.com"
Expand Down

0 comments on commit 939562f

Please sign in to comment.