diff --git a/rollbar/__init__.py b/rollbar/__init__.py index 905fd9a5..d2c82edd 100644 --- a/rollbar/__init__.py +++ b/rollbar/__init__.py @@ -1372,7 +1372,7 @@ def _build_starlette_request_data(request): 'params': dict(request.path_params), } - if hasattr(request, '_form'): + if hasattr(request, '_form') and request._form is not None: request_data['POST'] = { k: v.filename if isinstance(v, UploadFile) else v for k, v in request._form.items() @@ -1772,4 +1772,8 @@ def _wsgi_extract_user_ip(environ): def _starlette_extract_user_ip(request): + if not hasattr(request, 'client'): + return _extract_user_ip_from_headers(request) + if not hasattr(request.client, 'host'): + return _extract_user_ip_from_headers(request) return request.client.host or _extract_user_ip_from_headers(request) diff --git a/rollbar/contrib/fastapi/utils.py b/rollbar/contrib/fastapi/utils.py index db4902ec..daff1b2d 100644 --- a/rollbar/contrib/fastapi/utils.py +++ b/rollbar/contrib/fastapi/utils.py @@ -96,17 +96,20 @@ def get_installed_middlewares(app): return middlewares -def has_bare_routing(app_or_router): - expected_app_routes = 4 - expected_router_routes = 0 - - if ( - isinstance(app_or_router, FastAPI) - and expected_app_routes != len(app_or_router.routes) - ) or ( - isinstance(app_or_router, APIRouter) - and expected_router_routes != len(app_or_router.routes) - ): +def has_bare_routing(app_or_router: FastAPI | APIRouter): + if not isinstance(app_or_router, (FastAPI, APIRouter)): + return False + + urls = [ + getattr(app_or_router, 'openapi_url', None), + getattr(app_or_router, 'docs_url', None), + getattr(app_or_router, 'redoc_url', None), + getattr(app_or_router, 'swagger_ui_oauth2_redirect_url', None), + ] + + for route in app_or_router.routes: + if route is None or route.path in urls: + continue return False return True diff --git a/rollbar/lib/transforms/__init__.py b/rollbar/lib/transforms/__init__.py index 3f1b8802..4923b9c3 100644 --- a/rollbar/lib/transforms/__init__.py +++ b/rollbar/lib/transforms/__init__.py @@ -34,6 +34,8 @@ def transform(obj, transforms, key=None, batch_transforms=False): transforms = [BatchedTransform(transforms)] for transform in transforms: + if not isinstance(transform, Transform): + continue obj = _transform(obj, transform, key=key) return obj diff --git a/rollbar/test/fastapi_tests/test_logger.py b/rollbar/test/fastapi_tests/test_logger.py index 49a4a8a6..99a8d9a2 100644 --- a/rollbar/test/fastapi_tests/test_logger.py +++ b/rollbar/test/fastapi_tests/test_logger.py @@ -38,6 +38,7 @@ def test_should_add_framework_version_to_payload(self, mock_send_payload, *mocks app = FastAPI() app.add_middleware(LoggerMiddleware) + app.build_middleware_stack() rollbar.report_exc_info() @@ -70,10 +71,10 @@ def test_should_store_current_request(self, store_current_request): 'client': ['testclient', 50000], 'headers': [ (b'host', b'testserver'), - (b'user-agent', b'testclient'), - (b'accept-encoding', b'gzip, deflate'), (b'accept', b'*/*'), + (b'accept-encoding', b'gzip, deflate'), (b'connection', b'keep-alive'), + (b'user-agent', b'testclient'), ], 'http_version': '1.1', 'method': 'GET', diff --git a/rollbar/test/fastapi_tests/test_middleware.py b/rollbar/test/fastapi_tests/test_middleware.py index c49336e4..023754a0 100644 --- a/rollbar/test/fastapi_tests/test_middleware.py +++ b/rollbar/test/fastapi_tests/test_middleware.py @@ -16,6 +16,7 @@ import rollbar from rollbar.lib._async import AsyncMock from rollbar.test import BaseTest +from rollbar.test.utils import get_public_attrs ALLOWED_PYTHON_VERSION = sys.version_info >= (3, 6) @@ -152,6 +153,7 @@ def test_should_add_framework_version_to_payload(self, mock_send_payload, *mocks app = FastAPI() app.add_middleware(ReporterMiddleware) + app.build_middleware_stack() rollbar.report_exc_info() @@ -272,10 +274,10 @@ def test_should_store_current_request(self, store_current_request): 'client': ['testclient', 50000], 'headers': [ (b'host', b'testserver'), - (b'user-agent', b'testclient'), - (b'accept-encoding', b'gzip, deflate'), (b'accept', b'*/*'), + (b'accept-encoding', b'gzip, deflate'), (b'connection', b'keep-alive'), + (b'user-agent', b'testclient'), ], 'http_version': '1.1', 'method': 'GET', @@ -324,7 +326,7 @@ def test_should_return_current_request(self): async def read_root(original_request: Request): request = get_current_request() - self.assertEqual(request, original_request) + self.assertEqual(get_public_attrs(request), get_public_attrs(original_request)) client = TestClient(app) client.get('/') diff --git a/rollbar/test/starlette_tests/test_logger.py b/rollbar/test/starlette_tests/test_logger.py index 3ed51e68..5981e5cb 100644 --- a/rollbar/test/starlette_tests/test_logger.py +++ b/rollbar/test/starlette_tests/test_logger.py @@ -38,6 +38,7 @@ def test_should_add_framework_version_to_payload(self, mock_send_payload, *mocks app = Starlette() app.add_middleware(LoggerMiddleware) + app.build_middleware_stack() rollbar.report_exc_info() @@ -67,10 +68,10 @@ def test_should_store_current_request(self, store_current_request): 'client': ['testclient', 50000], 'headers': [ (b'host', b'testserver'), - (b'user-agent', b'testclient'), - (b'accept-encoding', b'gzip, deflate'), (b'accept', b'*/*'), + (b'accept-encoding', b'gzip, deflate'), (b'connection', b'keep-alive'), + (b'user-agent', b'testclient'), ], 'http_version': '1.1', 'method': 'GET', diff --git a/rollbar/test/starlette_tests/test_middleware.py b/rollbar/test/starlette_tests/test_middleware.py index 7c9f6554..e7415af0 100644 --- a/rollbar/test/starlette_tests/test_middleware.py +++ b/rollbar/test/starlette_tests/test_middleware.py @@ -16,6 +16,7 @@ import rollbar from rollbar.lib._async import AsyncMock from rollbar.test import BaseTest +from rollbar.test.utils import get_public_attrs ALLOWED_PYTHON_VERSION = sys.version_info >= (3, 6) @@ -138,6 +139,7 @@ def test_should_add_framework_version_to_payload(self, mock_send_payload, *mocks app = Starlette() app.add_middleware(ReporterMiddleware) + app.build_middleware_stack() rollbar.report_exc_info() @@ -243,10 +245,10 @@ def test_should_store_current_request(self, store_current_request): 'client': ['testclient', 50000], 'headers': [ (b'host', b'testserver'), - (b'user-agent', b'testclient'), - (b'accept-encoding', b'gzip, deflate'), (b'accept', b'*/*'), + (b'accept-encoding', b'gzip, deflate'), (b'connection', b'keep-alive'), + (b'user-agent', b'testclient'), ], 'http_version': '1.1', 'method': 'GET', @@ -290,7 +292,7 @@ def test_should_return_current_request(self): async def root(original_request): request = get_current_request() - self.assertEqual(request, original_request) + self.assertEqual(get_public_attrs(request), get_public_attrs(original_request)) return PlainTextResponse('OK') diff --git a/rollbar/test/starlette_tests/test_requests.py b/rollbar/test/starlette_tests/test_requests.py index 75bacb1d..09b0c292 100644 --- a/rollbar/test/starlette_tests/test_requests.py +++ b/rollbar/test/starlette_tests/test_requests.py @@ -10,6 +10,7 @@ import unittest from rollbar.test import BaseTest +from rollbar.test.utils import get_public_attrs ALLOWED_PYTHON_VERSION = sys.version_info >= (3, 6) @@ -49,7 +50,7 @@ def test_should_accept_request_param(self): stored_request = store_current_request(request) - self.assertEqual(request, stored_request) + self.assertEqual(get_public_attrs(request), get_public_attrs(stored_request)) def test_should_accept_scope_param_if_http_type(self): from starlette.requests import Request @@ -81,7 +82,7 @@ def test_should_accept_scope_param_if_http_type(self): request = store_current_request(scope, receive) - self.assertEqual(request, expected_request) + self.assertEqual(get_public_attrs(request), get_public_attrs(expected_request)) def test_should_not_accept_scope_param_if_not_http_type(self): from rollbar.contrib.starlette.requests import store_current_request diff --git a/rollbar/test/test_rollbar.py b/rollbar/test/test_rollbar.py index 3107e9e4..35b69766 100644 --- a/rollbar/test/test_rollbar.py +++ b/rollbar/test/test_rollbar.py @@ -20,6 +20,7 @@ from rollbar.lib import string_types from rollbar.test import BaseTest +from rollbar.test.utils import get_public_attrs try: eval(""" @@ -173,6 +174,7 @@ def test_starlette_request_data_with_consumed_body(self): body = b'body body body' scope = { 'type': 'http', + 'client': ('127.0.0.1', 1453), 'headers': [ (b'content-type', b'text/html'), (b'content-length', str(len(body)).encode('latin-1')), @@ -410,7 +412,7 @@ def test_get_request_starlette_middleware(self): def root(starlette_request): current_request = rollbar.get_request() - self.assertEqual(current_request, starlette_request) + self.assertEqual(get_public_attrs(current_request), get_public_attrs(starlette_request)) return PlainTextResponse("bye bye") @@ -437,7 +439,7 @@ def test_get_request_starlette_logger(self): def root(starlette_request): current_request = rollbar.get_request() - self.assertEqual(current_request, starlette_request) + self.assertEqual(get_public_attrs(current_request), get_public_attrs(starlette_request)) return PlainTextResponse("bye bye") @@ -465,7 +467,7 @@ def test_get_request_fastapi_middleware(self): def root(param, fastapi_request: Request): current_request = rollbar.get_request() - self.assertEqual(current_request, fastapi_request) + self.assertEqual(get_public_attrs(current_request), get_public_attrs(fastapi_request)) root = fastapi_add_route_with_request_param( app, root, '/{param}', 'fastapi_request' @@ -492,7 +494,7 @@ def test_get_request_fastapi_logger(self): def root(fastapi_request: Request): current_request = rollbar.get_request() - self.assertEqual(current_request, fastapi_request) + self.assertEqual(get_public_attrs(current_request), get_public_attrs(fastapi_request)) root = fastapi_add_route_with_request_param( app, root, '/{param}', 'fastapi_request' @@ -523,7 +525,7 @@ def test_get_request_fastapi_router(self): def root(fastapi_request: Request): current_request = rollbar.get_request() - self.assertEqual(current_request, fastapi_request) + self.assertEqual(get_public_attrs(current_request), get_public_attrs(fastapi_request)) root = fastapi_add_route_with_request_param( app, root, '/{param}', 'fastapi_request' diff --git a/rollbar/test/utils.py b/rollbar/test/utils.py new file mode 100644 index 00000000..88294592 --- /dev/null +++ b/rollbar/test/utils.py @@ -0,0 +1,4 @@ +from collections.abc import Mapping + +def get_public_attrs(obj: Mapping) -> dict: + return {k: obj[k] for k in obj if not k.startswith('_')}