From 3771917bd1e54c97ec168f7e5f0feef12010639c Mon Sep 17 00:00:00 2001 From: Gary Wilson Jr Date: Wed, 19 Apr 2023 11:16:33 -0500 Subject: [PATCH] Replaced direct config.path references and added tests for each plugin implementation. Added join_path utility function, with tests. --- spectree/config.py | 17 +++++++++--- spectree/plugins/falcon_plugin.py | 4 +-- spectree/plugins/flask_plugin.py | 11 ++++---- spectree/plugins/quart_plugin.py | 8 +++--- spectree/plugins/starlette_plugin.py | 4 +-- spectree/utils.py | 14 ++++++++++ tests/flask_imports/__init__.py | 7 ++++- tests/flask_imports/dry_plugin_flask.py | 20 +++++++++----- tests/quart_imports/__init__.py | 6 ++++- tests/quart_imports/dry_plugin_quart.py | 20 +++++++++----- tests/test_config.py | 30 +++++++++++++++++++++ tests/test_plugin_falcon.py | 18 +++++++++---- tests/test_plugin_falcon_asgi.py | 18 +++++++++---- tests/test_plugin_flask_blueprint.py | 35 ++++++++++++++++++++----- tests/test_plugin_quart.py | 3 +++ tests/test_plugin_starlette.py | 18 +++++++++---- tests/test_utils.py | 29 ++++++++++++++++++++ 17 files changed, 208 insertions(+), 54 deletions(-) diff --git a/spectree/config.py b/spectree/config.py index bc981d70..ee08c757 100644 --- a/spectree/config.py +++ b/spectree/config.py @@ -1,4 +1,3 @@ -import posixpath import warnings from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Union @@ -7,6 +6,7 @@ from .models import SecurityScheme, Server from .page import DEFAULT_PAGE_TEMPLATES +from .utils import join_path # Fall back to a str field if email-validator is not installed. if TYPE_CHECKING: @@ -118,9 +118,18 @@ def convert_to_lower_case(cls, values: Mapping[str, Any]) -> Dict[str, Any]: @property def spec_url(self) -> str: - sep = posixpath.sep - parts = (sep, self.path.lstrip(sep), self.filename.lstrip(sep)) - return posixpath.join(*parts) + return self.join_doc_path(self.filename) + + @property + def doc_root(self) -> str: + return self.join_doc_path("") + + def join_doc_path(self, path: str) -> str: + """ + Return the documentation path constructed using the configured + self.path documentation prefix and the given path string. + """ + return f"/{join_path((self.path, path))}" def swagger_oauth2_config(self) -> Dict[str, str]: """ diff --git a/spectree/plugins/falcon_plugin.py b/spectree/plugins/falcon_plugin.py index 892661eb..9673e2e3 100644 --- a/spectree/plugins/falcon_plugin.py +++ b/spectree/plugins/falcon_plugin.py @@ -76,11 +76,11 @@ def register_route(self, app: Any): ) for ui in self.config.page_templates: self.app.add_route( - f"/{self.config.path}/{ui}", + self.config.join_doc_path(ui), self.DOC_PAGE_ROUTE_CLASS( self.config.page_templates[ui], spec_url=self.config.spec_url, - spec_path=self.config.path, + spec_path=self.config.doc_root, **self.config.swagger_oauth2_config(), ), ) diff --git a/spectree/plugins/flask_plugin.py b/spectree/plugins/flask_plugin.py index 95a9ca88..c6575191 100644 --- a/spectree/plugins/flask_plugin.py +++ b/spectree/plugins/flask_plugin.py @@ -16,8 +16,7 @@ class FlaskPlugin(BasePlugin): def find_routes(self): for rule in current_app.url_map.iter_rules(): if any( - str(rule).startswith(path) - for path in (f"/{self.config.path}", "/static") + str(rule).startswith(path) for path in (self.config.doc_root, "/static") ): continue if rule.endpoint.startswith("openapi"): @@ -252,13 +251,13 @@ def gen_doc_page(ui): return self.config.page_templates[ui].format( spec_url=spec_url, - spec_path=self.config.path, + spec_path=self.config.doc_root, **self.config.swagger_oauth2_config(), ) for ui in self.config.page_templates: app.add_url_rule( - rule=f"/{self.config.path}/{ui}/", + rule=f"{self.config.join_doc_path(ui)}/", endpoint=f"openapi_{self.config.path}_{ui.replace('.', '_')}", view_func=lambda ui=ui: gen_doc_page(ui), ) @@ -267,11 +266,11 @@ def gen_doc_page(ui): else: for ui in self.config.page_templates: app.add_url_rule( - rule=f"/{self.config.path}/{ui}/", + rule=f"{self.config.join_doc_path(ui)}/", endpoint=f"openapi_{self.config.path}_{ui}", view_func=lambda ui=ui: self.config.page_templates[ui].format( spec_url=self.config.spec_url, - spec_path=self.config.path, + spec_path=self.config.doc_root, **self.config.swagger_oauth2_config(), ), ) diff --git a/spectree/plugins/quart_plugin.py b/spectree/plugins/quart_plugin.py index 28c9fb16..97514811 100644 --- a/spectree/plugins/quart_plugin.py +++ b/spectree/plugins/quart_plugin.py @@ -264,13 +264,13 @@ def gen_doc_page(ui): return self.config.page_templates[ui].format( spec_url=spec_url, - spec_path=self.config.path, + spec_path=self.config.doc_root, **self.config.swagger_oauth2_config(), ) for ui in self.config.page_templates: app.add_url_rule( - f"/{self.config.path}/{ui}/", + f"{self.config.join_doc_path(ui)}/", endpoint=f"openapi_{self.config.path}_{ui.replace('.', '_')}", view_func=lambda ui=ui: gen_doc_page(ui), ) @@ -279,11 +279,11 @@ def gen_doc_page(ui): else: for ui in self.config.page_templates: app.add_url_rule( - f"/{self.config.path}/{ui}/", + f"{self.config.join_doc_path(ui)}/", endpoint=f"openapi_{self.config.path}_{ui}", view_func=lambda ui=ui: self.config.page_templates[ui].format( spec_url=self.config.spec_url, - spec_path=self.config.path, + spec_path=self.config.doc_root, **self.config.swagger_oauth2_config(), ), ) diff --git a/spectree/plugins/starlette_plugin.py b/spectree/plugins/starlette_plugin.py index f0ef2226..568f23e2 100644 --- a/spectree/plugins/starlette_plugin.py +++ b/spectree/plugins/starlette_plugin.py @@ -45,11 +45,11 @@ def register_route(self, app): for ui in self.config.page_templates: self.app.add_route( - f"/{self.config.path}/{ui}", + self.config.join_doc_path(ui), lambda request, ui=ui: HTMLResponse( self.config.page_templates[ui].format( spec_url=self.config.spec_url, - spec_path=self.config.path, + spec_path=self.config.doc_root, **self.config.swagger_oauth2_config(), ) ), diff --git a/spectree/utils.py b/spectree/utils.py index 52af7c74..c7e0f49d 100644 --- a/spectree/utils.py +++ b/spectree/utils.py @@ -6,6 +6,7 @@ Any, Callable, Dict, + Iterable, Iterator, List, Mapping, @@ -349,3 +350,16 @@ def parse_resp(func: Any, naming_strategy: NamingStrategy = get_model_key): responses = func.resp.generate_spec(naming_strategy) return responses + + +def join_path(paths: Iterable[str]) -> str: + """ + Join path parts together. + + Normalizes any leading or trailing slashes on the given path parts and avoids + behavior of urllib.parse.urljoin and posixpath.join where a path part with + leading slash overrides any path parts before it due to being treated as an + absolute path. + """ + normalized_paths = (x.strip("/") for x in paths if x) + return "/".join(x for x in normalized_paths if x) diff --git a/tests/flask_imports/__init__.py b/tests/flask_imports/__init__.py index 2a05aa54..8ff58a04 100644 --- a/tests/flask_imports/__init__.py +++ b/tests/flask_imports/__init__.py @@ -1,4 +1,9 @@ -from .dry_plugin_flask import ( +import pytest + +# Enable pytest assertion rewriting for dry module. Must come before module is imported. +pytest.register_assert_rewrite("tests.flask_imports.dry_plugin_flask") + +from .dry_plugin_flask import ( # noqa: E402 test_flask_doc, test_flask_list_json_request, test_flask_no_response, diff --git a/tests/flask_imports/dry_plugin_flask.py b/tests/flask_imports/dry_plugin_flask.py index 3144a4fc..e4b7514f 100644 --- a/tests/flask_imports/dry_plugin_flask.py +++ b/tests/flask_imports/dry_plugin_flask.py @@ -73,28 +73,36 @@ def test_flask_validation_error_response_status_code( @pytest.mark.parametrize( - "test_client_and_api, expected_doc_pages", + "test_client_and_api, doc_prefix, expected_doc_pages", [ - pytest.param({}, ["redoc", "swagger"], id="default-page-templates"), + pytest.param({}, "/apidoc", ["redoc", "swagger"], id="default-page-templates"), + pytest.param( + {"api_kwargs": {"path": ""}}, + "", + ["redoc", "swagger"], + id="default-page-templates-in-root-path", + ), pytest.param( {"api_kwargs": {"page_templates": {"custom_page": "{spec_url}"}}}, + "/apidoc", ["custom_page"], id="custom-page-templates", ), ], indirect=["test_client_and_api"], ) -def test_flask_doc(test_client_and_api, expected_doc_pages): +def test_flask_doc(test_client_and_api, doc_prefix, expected_doc_pages): client, api = test_client_and_api - resp = client.get("/apidoc/openapi.json") + resp = client.get(f"{doc_prefix}/openapi.json") + assert resp.status_code == 200 assert resp.json == api.spec for doc_page in expected_doc_pages: - resp = client.get(f"/apidoc/{doc_page}/") + resp = client.get(f"{doc_prefix}/{doc_page}/") assert resp.status_code == 200 - resp = client.get(f"/apidoc/{doc_page}") + resp = client.get(f"{doc_prefix}/{doc_page}") assert resp.status_code == 308 diff --git a/tests/quart_imports/__init__.py b/tests/quart_imports/__init__.py index 8056bbf8..40b43746 100644 --- a/tests/quart_imports/__init__.py +++ b/tests/quart_imports/__init__.py @@ -1,4 +1,8 @@ -from .dry_plugin_quart import ( +import pytest + +pytest.register_assert_rewrite("tests.quart_imports.dry_plugin_quart") + +from .dry_plugin_quart import ( # noqa: E402 test_quart_doc, test_quart_no_response, test_quart_return_model, diff --git a/tests/quart_imports/dry_plugin_quart.py b/tests/quart_imports/dry_plugin_quart.py index 2ddca24a..4051c9d4 100644 --- a/tests/quart_imports/dry_plugin_quart.py +++ b/tests/quart_imports/dry_plugin_quart.py @@ -79,28 +79,36 @@ def test_quart_validation_error_response_status_code( @pytest.mark.parametrize( - "test_client_and_api, expected_doc_pages", + "test_client_and_api, doc_prefix, expected_doc_pages", [ - pytest.param({}, ["redoc", "swagger"], id="default-page-templates"), + pytest.param({}, "/apidoc", ["redoc", "swagger"], id="default-page-templates"), + pytest.param( + {"api_kwargs": {"path": ""}}, + "", + ["redoc", "swagger"], + id="default-page-templates-in-root-path", + ), pytest.param( {"api_kwargs": {"page_templates": {"custom_page": "{spec_url}"}}}, + "/apidoc", ["custom_page"], id="custom-page-templates", ), ], indirect=["test_client_and_api"], ) -def test_quart_doc(test_client_and_api, expected_doc_pages): +def test_quart_doc(test_client_and_api, doc_prefix, expected_doc_pages): client, api = test_client_and_api - resp = asyncio.run(client.get("/apidoc/openapi.json")) + resp = asyncio.run(client.get(f"{doc_prefix}/openapi.json")) + assert resp.status_code == 200 assert asyncio.run(resp.json) == api.spec for doc_page in expected_doc_pages: - resp = asyncio.run(client.get(f"/apidoc/{doc_page}/")) + resp = asyncio.run(client.get(f"{doc_prefix}/{doc_page}/")) assert resp.status_code == 200 - resp = asyncio.run(client.get(f"/apidoc/{doc_page}")) + resp = asyncio.run(client.get(f"{doc_prefix}/{doc_page}")) assert resp.status_code == 308 diff --git a/tests/test_config.py b/tests/test_config.py index 208222cb..6aaadcc0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -82,6 +82,36 @@ def test_config_spec_url_when_given_path_and_filename(path, filename): assert config.spec_url == "/prefix/openapi.json" +@pytest.mark.parametrize( + "config_path, test_path, expected_path", + [ + pytest.param("prefix", "", "/prefix", id="root"), + pytest.param( + "prefix", "swagger", "/prefix/swagger", id="test-path-no-trailing-slash" + ), + pytest.param( + "prefix", "swagger/", "/prefix/swagger", id="test-path-trailing-slash" + ), + ], +) +def test_config_join_doc_path(config_path, test_path, expected_path): + config = Configuration(path=config_path) + assert config.join_doc_path(test_path) == expected_path + + +@pytest.mark.parametrize( + "config_path, expected_doc_root_path", + [ + pytest.param("", "/", id="root"), + pytest.param("prefix", "/prefix", id="path-no-trailing-slash"), + pytest.param("prefix/", "/prefix", id="path-trailing-slash"), + ], +) +def test_config_doc_root(config_path, expected_doc_root_path): + config = Configuration(path=config_path) + assert config.doc_root == expected_doc_root_path + + @pytest.mark.parametrize(("secure_item"), SECURITY_SCHEMAS) def test_update_security_scheme(secure_item: Type[SecurityScheme]): # update and validate each schema type diff --git a/tests/test_plugin_falcon.py b/tests/test_plugin_falcon.py index e91a9324..d7d71b74 100644 --- a/tests/test_plugin_falcon.py +++ b/tests/test_plugin_falcon.py @@ -400,25 +400,33 @@ def test_validation_error_response_status_code( @pytest.mark.parametrize( - "test_client_and_api, expected_doc_pages", + "test_client_and_api, doc_prefix, expected_doc_pages", [ - pytest.param({}, ["redoc", "swagger"], id="default-page-templates"), + pytest.param({}, "/apidoc", ["redoc", "swagger"], id="default-page-templates"), + pytest.param( + {"api_kwargs": {"path": ""}}, + "", + ["redoc", "swagger"], + id="default-page-templates-in-root-path", + ), pytest.param( {"api_kwargs": {"page_templates": {"custom_page": "{spec_url}"}}}, + "/apidoc", ["custom_page"], id="custom-page-templates", ), ], indirect=["test_client_and_api"], ) -def test_falcon_doc(test_client_and_api, expected_doc_pages): +def test_falcon_doc(test_client_and_api, doc_prefix, expected_doc_pages): client, api = test_client_and_api - resp = client.simulate_get("/apidoc/openapi.json") + resp = client.simulate_get(f"{doc_prefix}/openapi.json") + assert resp.status_code == 200 assert resp.json == api.spec for doc_page in expected_doc_pages: - resp = client.simulate_get(f"/apidoc/{doc_page}") + resp = client.simulate_get(f"{doc_prefix}/{doc_page}") assert resp.status_code == 200 diff --git a/tests/test_plugin_falcon_asgi.py b/tests/test_plugin_falcon_asgi.py index 056555b5..a6c4459d 100644 --- a/tests/test_plugin_falcon_asgi.py +++ b/tests/test_plugin_falcon_asgi.py @@ -310,25 +310,33 @@ def test_validation_error_response_status_code( @pytest.mark.parametrize( - "test_client_and_api, expected_doc_pages", + "test_client_and_api, doc_prefix, expected_doc_pages", [ - pytest.param({}, ["redoc", "swagger"], id="default-page-templates"), + pytest.param({}, "/apidoc", ["redoc", "swagger"], id="default-page-templates"), + pytest.param( + {"api_kwargs": {"path": ""}}, + "", + ["redoc", "swagger"], + id="default-page-templates-in-root-path", + ), pytest.param( {"api_kwargs": {"page_templates": {"custom_page": "{spec_url}"}}}, + "/apidoc", ["custom_page"], id="custom-page-templates", ), ], indirect=["test_client_and_api"], ) -def test_falcon_doc(test_client_and_api, expected_doc_pages): +def test_falcon_doc(test_client_and_api, doc_prefix, expected_doc_pages): client, api = test_client_and_api - resp = client.simulate_get("/apidoc/openapi.json") + resp = client.simulate_get(f"{doc_prefix}/openapi.json") + assert resp.status_code == 200 assert resp.json == api.spec for doc_page in expected_doc_pages: - resp = client.simulate_get(f"/apidoc/{doc_page}") + resp = client.simulate_get(f"{doc_prefix}/{doc_page}") assert resp.status_code == 200 diff --git a/tests/test_plugin_flask_blueprint.py b/tests/test_plugin_flask_blueprint.py index 619b567a..98704cee 100644 --- a/tests/test_plugin_flask_blueprint.py +++ b/tests/test_plugin_flask_blueprint.py @@ -231,23 +231,44 @@ def ping(): @pytest.mark.parametrize( - ("test_client_and_api", "prefix"), + ("test_client_and_api", "prefix", "doc_prefix"), [ - ({"register_blueprint_kwargs": {}}, ""), - ({"register_blueprint_kwargs": {"url_prefix": "/prefix"}}, "/prefix"), + pytest.param({"register_blueprint_kwargs": {}}, "", "/apidoc", id="default"), + pytest.param( + {"register_blueprint_kwargs": {}, "api_kwargs": {"path": ""}}, + "", + "", + id="root-doc-prefix", + ), + pytest.param( + {"register_blueprint_kwargs": {"url_prefix": "/prefix"}}, + "/prefix", + "/apidoc", + id="blueprint-prefix", + ), + pytest.param( + { + "register_blueprint_kwargs": {"url_prefix": "/prefix"}, + "api_kwargs": {"path": ""}, + }, + "/prefix", + "", + id="blueprint-prefix-with-root-doc-prefix", + ), ], indirect=["test_client_and_api"], ) -def test_flask_doc_prefix(test_client_and_api, prefix): +def test_flask_doc_prefix(test_client_and_api, prefix, doc_prefix): client, api = test_client_and_api - resp = client.get(prefix + "/apidoc/openapi.json") + resp = client.get(f"{prefix}{doc_prefix}/openapi.json") + assert resp.status_code == 200 assert resp.json == api.spec - resp = client.get(prefix + "/apidoc/redoc/") + resp = client.get(f"{prefix}{doc_prefix}/redoc/") assert resp.status_code == 200 - resp = client.get(prefix + "/apidoc/swagger/") + resp = client.get(f"{prefix}{doc_prefix}/swagger/") assert resp.status_code == 200 assert get_paths(api.spec) == [ diff --git a/tests/test_plugin_quart.py b/tests/test_plugin_quart.py index f320c5cd..c22dffb2 100644 --- a/tests/test_plugin_quart.py +++ b/tests/test_plugin_quart.py @@ -18,6 +18,9 @@ api_tag, ) +# import tests to execute +from .quart_imports import * # NOQA + def before_handler(req, resp, err, _): if err: diff --git a/tests/test_plugin_starlette.py b/tests/test_plugin_starlette.py index ad71313c..fa1c8bb4 100644 --- a/tests/test_plugin_starlette.py +++ b/tests/test_plugin_starlette.py @@ -339,25 +339,33 @@ def test_validation_error_response_status_code( @pytest.mark.parametrize( - "test_client_and_api, expected_doc_pages", + "test_client_and_api, doc_prefix, expected_doc_pages", [ - pytest.param({}, ["redoc", "swagger"], id="default-page-templates"), + pytest.param({}, "/apidoc", ["redoc", "swagger"], id="default-page-templates"), + pytest.param( + {"api_kwargs": {"path": ""}}, + "", + ["redoc", "swagger"], + id="default-page-templates-in-root-path", + ), pytest.param( {"api_kwargs": {"page_templates": {"custom_page": "{spec_url}"}}}, + "/apidoc", ["custom_page"], id="custom-page-templates", ), ], indirect=["test_client_and_api"], ) -def test_starlette_doc(test_client_and_api, expected_doc_pages): +def test_starlette_doc(test_client_and_api, doc_prefix, expected_doc_pages): client, api = test_client_and_api - resp = client.get("/apidoc/openapi.json") + resp = client.get(f"{doc_prefix}/openapi.json") + assert resp.status_code == 200 assert resp.json() == api.spec for doc_page in expected_doc_pages: - resp = client.get(f"/apidoc/{doc_page}") + resp = client.get(f"{doc_prefix}/{doc_page}") assert resp.status_code == 200 diff --git a/tests/test_utils.py b/tests/test_utils.py index 802c0215..aa607692 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,12 @@ +import itertools + import pytest from spectree.response import Response from spectree.spec import SpecTree from spectree.utils import ( has_model, + join_path, parse_code, parse_comments, parse_name, @@ -279,3 +282,29 @@ def test_parse_params_with_route_param_keywords(): "explode": True, }, ] + + +@pytest.mark.parametrize( + "paths, expected", + [ + pytest.param([""], "", id="empty-string"), + pytest.param(["", ""], "", id="multiple-empty-string"), + pytest.param([" "], " ", id="space"), + pytest.param(["/"], "", id="slash"), + pytest.param(["/", "/"], "", id="multiple-slash"), + pytest.param(["//"], "", id="double-slash"), + pytest.param(["dir1"], "dir1", id="single-path"), + pytest.param(["dir1", "dir2"], "dir1/dir2", id="multiple-paths"), + ] + + [ + pytest.param(paths, "dir1/dir2", id=f"leading-trailing-slashes-{num}") + for num, paths in enumerate( + itertools.product( + ["dir1", "/dir1", "dir1/", "/dir1/"], + ["dir2", "/dir2", "dir2/", "/dir2/"], + ) + ) + ], +) +def test_join_path(paths, expected): + assert join_path(paths) == expected