Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tests/files_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import mocket
import pytest
import requests as python_requests
from local_test_server import uses_local_server


@pytest.fixture
Expand All @@ -20,7 +21,7 @@ def log_stream():

@pytest.fixture
def post_url():
return "https://httpbin.org/post"
return "http://127.0.0.1:5000/post"


@pytest.fixture
Expand Down Expand Up @@ -63,6 +64,7 @@ def get_actual_request_data(log_stream):
return boundary, content_length, actual_request_post


@uses_local_server
def test_post_file_as_data( # pylint: disable=unused-argument
requests, sock, log_stream, post_url, request_logging
):
Expand All @@ -85,6 +87,7 @@ def test_post_file_as_data( # pylint: disable=unused-argument
assert sent.endswith(actual_request_post)


@uses_local_server
def test_post_files_text( # pylint: disable=unused-argument
sock, requests, log_stream, post_url, request_logging
):
Expand Down Expand Up @@ -120,6 +123,7 @@ def test_post_files_text( # pylint: disable=unused-argument
assert sent.endswith(actual_request_post)


@uses_local_server
def test_post_files_file( # pylint: disable=unused-argument
sock, requests, log_stream, post_url, request_logging
):
Expand Down Expand Up @@ -164,6 +168,7 @@ def test_post_files_file( # pylint: disable=unused-argument
assert sent.endswith(actual_request_post)


@uses_local_server
def test_post_files_complex( # pylint: disable=unused-argument
sock, requests, log_stream, post_url, request_logging
):
Expand Down
36 changes: 36 additions & 0 deletions tests/local_test_server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,47 @@
# SPDX-FileCopyrightText: 2025 Tim Cocks
#
# SPDX-License-Identifier: MIT
import functools
import json
import socketserver
import threading
import time
from http.server import SimpleHTTPRequestHandler


def uses_local_server(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
with ReusableAddressTCPServer(("127.0.0.1", 5000), LocalTestServerHandler) as server:
server_thread = threading.Thread(target=server.serve_forever)
server_thread.daemon = True
server_thread.start()
time.sleep(2) # Give the server some time to start

result = func(*args, **kwargs)

server.shutdown()
server.server_close()
return result

return wrapper


class ReusableAddressTCPServer(socketserver.TCPServer):
# Enable SO_REUSEADDR
allow_reuse_address = True


class LocalTestServerHandler(SimpleHTTPRequestHandler):
def do_POST(self):
if self.path == "/post":
resp_body = json.dumps({"url": "http://localhost:5000/post"}).encode("utf-8")
self.send_response(200)
self.send_header("Content-type", "application/json")
self.send_header("Content-Length", str(len(resp_body)))
self.end_headers()
self.wfile.write(resp_body)

def do_GET(self):
if self.path == "/get":
resp_body = json.dumps({"url": "http://localhost:5000/get"}).encode("utf-8")
Expand Down
57 changes: 34 additions & 23 deletions tests/real_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@

import adafruit_connection_manager
import pytest
from local_test_server import LocalTestServerHandler
from local_test_server import uses_local_server

import adafruit_requests


@uses_local_server
def test_gets():
path_index = 0
status_code_index = 1
Expand All @@ -28,25 +29,35 @@ def test_gets():
("status/204", 204, "", None),
]

with socketserver.TCPServer(("127.0.0.1", 5000), LocalTestServerHandler) as server:
server_thread = threading.Thread(target=server.serve_forever)
server_thread.daemon = True
server_thread.start()

time.sleep(2) # Give the server some time to start

for case in cases:
requests = adafruit_requests.Session(socket, ssl.create_default_context())
with requests.get(f"http://127.0.0.1:5000/{case[path_index]}") as response:
assert response.status_code == case[status_code_index]
if case[text_result_index] is not None:
assert response.text == case[text_result_index]
if case[json_keys_index] is not None:
for key, value in case[json_keys_index].items():
assert response.json()[key] == value

adafruit_connection_manager.connection_manager_close_all(release_references=True)

server.shutdown()
server.server_close()
time.sleep(2)
for case in cases:
requests = adafruit_requests.Session(socket, ssl.create_default_context())
with requests.get(f"http://127.0.0.1:5000/{case[path_index]}") as response:
assert response.status_code == case[status_code_index]
if case[text_result_index] is not None:
assert response.text == case[text_result_index]
if case[json_keys_index] is not None:
for key, value in case[json_keys_index].items():
assert response.json()[key] == value

adafruit_connection_manager.connection_manager_close_all(release_references=True)


@pytest.mark.parametrize(
("allow_redirects", "status_code"),
(
(True, 200),
(False, 301),
),
)
def test_http_to_https_redirect(allow_redirects, status_code):
url = "http://www.adafruit.com/api/quotes.php"
requests = adafruit_requests.Session(socket, ssl.create_default_context())
with requests.get(url, allow_redirects=allow_redirects) as response:
assert response.status_code == status_code


def test_https_direct():
url = "https://www.adafruit.com/api/quotes.php"
requests = adafruit_requests.Session(socket, ssl.create_default_context())
with requests.get(url) as response:
assert response.status_code == 200