diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 25711c8..cca105d 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -7,7 +7,13 @@ Issues and pull requests are more than welcome: https://github.com/developmentse
```bash
$ git clone https://github.com/developmentseed/tifeatures.git
$ cd tifeatures
-$ pip install -e .[dev]
+$ pip install -e .["test,dev"]
+```
+
+You can then run the tests with the following command:
+
+```sh
+python -m pytest --cov tifeatures --cov-report term-missing
```
**pre-commit**
diff --git a/pyproject.toml b/pyproject.toml
index 03981d9..8584f5a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ dynamic = ["version"]
dependencies = [
"asyncpg>=0.23.0",
"buildpg>=0.3",
- "fastapi>=0.73",
+ "fastapi>=0.77",
"jinja2>=2.11.2,<4.0.0",
"geojson-pydantic",
"starlette-cramjam>=0.1.0,<0.2",
diff --git a/tests/conftest.py b/tests/conftest.py
index fff5e0f..81a7f1e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -32,9 +32,16 @@ def database_url(test_db):
def app(database_url, monkeypatch):
"""Create app with connection to the pytest database."""
monkeypatch.setenv("DATABASE_URL", str(database_url))
+ monkeypatch.setenv(
+ "TIFEATURES_TEMPLATE_DIRECTORY", os.path.join(DATA_DIR, "templates")
+ )
from tifeatures.main import app
+ # Remove middlewares https://github.com/encode/starlette/issues/472
+ app.user_middleware = []
+ app.middleware_stack = app.build_middleware_stack()
+
# register functions to app.state.function_catalog here
with TestClient(app) as app:
diff --git a/tests/fixtures/templates/collections.html b/tests/fixtures/templates/collections.html
new file mode 100644
index 0000000..f0976c2
--- /dev/null
+++ b/tests/fixtures/templates/collections.html
@@ -0,0 +1,26 @@
+{% include "header.html" %}
+
+
Custom Collections
+
+
+
+
+
+ | Title |
+ Type |
+ Description |
+
+
+
+{% for collection in response.collections %}
+
+ | {{ collection.title or collection.id }} |
+ {{ collection.itemType }} |
+ {{ collection.description or collection.title or collection.id }} |
+
+{% endfor %}
+
+
+
+
+{% include "footer.html" %}
diff --git a/tests/routes/test_templates.py b/tests/routes/test_templates.py
new file mode 100644
index 0000000..099b679
--- /dev/null
+++ b/tests/routes/test_templates.py
@@ -0,0 +1,11 @@
+"""Test HTML templates."""
+
+
+def test_custom_templates(app):
+ """Test /collections endpoint."""
+ response = app.get("/collections")
+ assert response.status_code == 200
+
+ response = app.get("/collections?f=html")
+ assert response.status_code == 200
+ assert "Custom Collections" in response.text
diff --git a/tifeatures/factory.py b/tifeatures/factory.py
index 801801c..08e1578 100644
--- a/tifeatures/factory.py
+++ b/tifeatures/factory.py
@@ -1,10 +1,11 @@
"""tifeatures.factory: router factories."""
import json
-import pathlib
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional
+import jinja2
+
from tifeatures import model
from tifeatures.dependencies import (
CollectionParams,
@@ -19,20 +20,31 @@
from tifeatures.settings import APISettings
from fastapi import APIRouter, Depends, Path, Query
-from fastapi.templating import Jinja2Templates
from starlette.datastructures import QueryParams
from starlette.requests import Request
-from starlette.responses import HTMLResponse
+from starlette.templating import Jinja2Templates, _TemplateResponse
-template_dir = str(pathlib.Path(__file__).parent.joinpath("templates"))
-templates = Jinja2Templates(directory=template_dir)
settings = APISettings()
+# custom template directory
+templates_location: List[Any] = (
+ [jinja2.FileSystemLoader(settings.template_directory)]
+ if settings.template_directory
+ else []
+)
+# default template directory
+templates_location.append(jinja2.PackageLoader(__package__, "templates"))
+
+templates = Jinja2Templates(
+ directory="", # we need to set a dummy directory variable, see https://github.com/encode/starlette/issues/1214
+ loader=jinja2.ChoiceLoader(templates_location),
+)
+
def create_html_response(
request: Request, data: str, template_name: str
-) -> HTMLResponse:
+) -> _TemplateResponse:
"""Create Template response."""
urlpath = request.url.path
crumbs = []
diff --git a/tifeatures/settings.py b/tifeatures/settings.py
index 1af6e56..a21ef0f 100644
--- a/tifeatures/settings.py
+++ b/tifeatures/settings.py
@@ -12,6 +12,7 @@ class _APISettings(pydantic.BaseSettings):
name: str = "TiFeatures"
cors_origins: str = "*"
cachecontrol: str = "public, max-age=3600"
+ template_directory: Optional[str] = None
@pydantic.validator("cors_origins")
def parse_cors_origin(cls, v):