Skip to content

Commit 0885f0b

Browse files
move back to attrs (#729)
* move back to attrs * update changelog * edit tests * more doc
1 parent 494e485 commit 0885f0b

File tree

13 files changed

+136
-157
lines changed

13 files changed

+136
-157
lines changed

CHANGES.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
## [Unreleased] - TBD
44

5+
## [3.0.0b2] - 2024-07-09
6+
7+
### Changed
8+
9+
* move back to `@attrs` (instead of dataclass) for `APIRequest` (model for GET request) class type [#729](https://github.com/stac-utils/stac-fastapi/pull/729)
10+
511
## [3.0.0b1] - 2024-07-05
612

713
### Added
@@ -432,7 +438,8 @@
432438

433439
* First PyPi release!
434440

435-
[Unreleased]: <https://github.com/stac-utils/stac-fastapi/compare/3.0.0b1..main>
441+
[Unreleased]: <https://github.com/stac-utils/stac-fastapi/compare/3.0.0b2..main>
442+
[3.0.0b2]: <https://github.com/stac-utils/stac-fastapi/compare/3.0.0b1..3.0.0b2>
436443
[3.0.0b1]: <https://github.com/stac-utils/stac-fastapi/compare/3.0.0a4..3.0.0b1>
437444
[3.0.0a4]: <https://github.com/stac-utils/stac-fastapi/compare/3.0.0a3..3.0.0a4>
438445
[3.0.0a3]: <https://github.com/stac-utils/stac-fastapi/compare/3.0.0a2..3.0.0a3>

docs/src/migrations/v3.0.0.md

Lines changed: 33 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -23,49 +23,6 @@ In addition to pydantic v2 update, `stac-pydantic` has been updated to better ma
2323

2424
* `PostFieldsExtension.filter_fields` property has been removed.
2525

26-
## `attr` -> `dataclass` for APIRequest models
27-
28-
Models for **GET** requests, defining the path and query parameters, now uses python `dataclass` instead of `attr`.
29-
30-
```python
31-
# before
32-
@attr.s
33-
class CollectionModel(APIRequest):
34-
collections: Optional[str] = attr.ib(default=None, converter=str2list)
35-
36-
# now
37-
@dataclass
38-
class CollectionModel(APIRequest):
39-
collections: Annotated[Optional[str], Query()] = None
40-
41-
def __post_init__(self):
42-
"""convert attributes."""
43-
if self.collections:
44-
self.collections = str2list(self.collections) # type: ignore
45-
46-
```
47-
48-
!!! warning
49-
50-
if you want to extend a class with a `required` attribute (without default), you will have to write all the attributes to avoid having *non-default* attributes defined after *default* attributes (ref: https://github.com/stac-utils/stac-fastapi/pull/714/files#r1651557338)
51-
52-
```python
53-
@dataclass
54-
class A:
55-
value: Annotated[str, Query()]
56-
57-
# THIS WON'T WORK
58-
@dataclass
59-
class B(A):
60-
another_value: Annotated[str, Query(...)]
61-
62-
# DO THIS
63-
@dataclass
64-
class B(A):
65-
another_value: Annotated[str, Query(...)]
66-
value: Annotated[str, Query()]
67-
```
68-
6926
## Middlewares configuration
7027

7128
The `StacApi.middlewares` attribute has been updated to accept a list of `starlette.middleware.Middleware`. This enables dynamic configuration of middlewares (see https://github.com/stac-utils/stac-fastapi/pull/442).
@@ -113,9 +70,9 @@ stac = StacApi(
11370
)
11471

11572
# now
116-
@dataclass
73+
@attr.s
11774
class CollectionsRequest(APIRequest):
118-
user: str = Query(...)
75+
user: Annotated[str, Query(...)] = attr.ib()
11976

12077
stac = StacApi(
12178
search_get_request_model=getSearchModel,
@@ -127,6 +84,37 @@ stac = StacApi(
12784
)
12885
```
12986

87+
## APIRequest - GET Request Model
88+
89+
Most of the **GET** endpoints are configured with `stac_fastapi.types.search.APIRequest` base class.
90+
91+
e.g the BaseSearchGetRequest, default for the `GET - /search` endpoint:
92+
93+
```python
94+
@attr.s
95+
class BaseSearchGetRequest(APIRequest):
96+
"""Base arguments for GET Request."""
97+
98+
collections: Annotated[Optional[str], Query()] = attr.ib(
99+
default=None, converter=str2list
100+
)
101+
ids: Annotated[Optional[str], Query()] = attr.ib(default=None, converter=str2list)
102+
bbox: Annotated[Optional[BBox], Query()] = attr.ib(default=None, converter=str2bbox)
103+
intersects: Annotated[Optional[str], Query()] = attr.ib(default=None)
104+
datetime: Annotated[Optional[DateTimeType], Query()] = attr.ib(
105+
default=None, converter=str_to_interval
106+
)
107+
limit: Annotated[Optional[int], Query()] = attr.ib(default=10)
108+
```
109+
110+
We use [*python attrs*](https://www.attrs.org/en/stable/) to construct those classes. **Type Hint** for each attribute is important and should be defined using `Annotated[{type}, fastapi.Query()]` form.
111+
112+
```python
113+
@attr.s
114+
class SomeRequest(APIRequest):
115+
user_number: Annotated[Optional[int], Query(alias="user-number")] = attr.ib(default=None)
116+
```
117+
130118
## Filter extension
131119

132120
`default_includes` attribute has been removed from the `ApiSettings` object. If you need `defaults` includes you can overwrite the `FieldExtension` models (see https://github.com/stac-utils/stac-fastapi/pull/706).

stac_fastapi/api/stac_fastapi/api/models.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Api request/response models."""
22

3-
from dataclasses import dataclass, make_dataclass
43
from typing import List, Optional, Type, Union
54

5+
import attr
66
from fastapi import Path, Query
77
from pydantic import BaseModel, create_model
88
from stac_pydantic.shared import BBox
@@ -43,11 +43,11 @@ def create_request_model(
4343

4444
mixins = mixins or []
4545

46-
models = extension_models + mixins + [base_model]
46+
models = [base_model] + extension_models + mixins
4747

4848
# Handle GET requests
4949
if all([issubclass(m, APIRequest) for m in models]):
50-
return make_dataclass(model_name, [], bases=tuple(models))
50+
return attr.make_class(model_name, attrs={}, bases=tuple(models))
5151

5252
# Handle POST requests
5353
elif all([issubclass(m, BaseModel) for m in models]):
@@ -86,43 +86,38 @@ def create_post_request_model(
8686
)
8787

8888

89-
@dataclass
89+
@attr.s
9090
class CollectionUri(APIRequest):
9191
"""Get or delete collection."""
9292

93-
collection_id: Annotated[str, Path(description="Collection ID")]
93+
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
9494

9595

96-
@dataclass
96+
@attr.s
9797
class ItemUri(APIRequest):
9898
"""Get or delete item."""
9999

100-
collection_id: Annotated[str, Path(description="Collection ID")]
101-
item_id: Annotated[str, Path(description="Item ID")]
100+
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
101+
item_id: Annotated[str, Path(description="Item ID")] = attr.ib()
102102

103103

104-
@dataclass
104+
@attr.s
105105
class EmptyRequest(APIRequest):
106106
"""Empty request."""
107107

108108
...
109109

110110

111-
@dataclass
111+
@attr.s
112112
class ItemCollectionUri(APIRequest):
113113
"""Get item collection."""
114114

115-
collection_id: Annotated[str, Path(description="Collection ID")]
116-
limit: Annotated[int, Query()] = 10
117-
bbox: Annotated[Optional[BBox], Query()] = None
118-
datetime: Annotated[Optional[DateTimeType], Query()] = None
119-
120-
def __post_init__(self):
121-
"""convert attributes."""
122-
if self.bbox:
123-
self.bbox = str2bbox(self.bbox) # type: ignore
124-
if self.datetime:
125-
self.datetime = str_to_interval(self.datetime) # type: ignore
115+
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
116+
limit: Annotated[int, Query()] = attr.ib(default=10)
117+
bbox: Annotated[Optional[BBox], Query()] = attr.ib(default=None, converter=str2bbox)
118+
datetime: Annotated[Optional[DateTimeType], Query()] = attr.ib(
119+
default=None, converter=str_to_interval
120+
)
126121

127122

128123
class GeoJSONResponse(JSONResponse):

stac_fastapi/api/tests/test_app.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from dataclasses import dataclass
21
from datetime import datetime
32
from typing import List, Optional, Union
43

4+
import attr
55
import pytest
66
from fastapi import Path, Query
77
from fastapi.testclient import TestClient
88
from pydantic import ValidationError
99
from stac_pydantic import api
10+
from typing_extensions import Annotated
1011

1112
from stac_fastapi.api import app
1213
from stac_fastapi.api.models import (
@@ -328,25 +329,25 @@ def item_collection(
328329
def test_request_model(AsyncTestCoreClient):
329330
"""Test if request models are passed correctly."""
330331

331-
@dataclass
332+
@attr.s
332333
class CollectionsRequest(APIRequest):
333-
user: str = Query(...)
334+
user: Annotated[str, Query(...)] = attr.ib()
334335

335-
@dataclass
336+
@attr.s
336337
class CollectionRequest(APIRequest):
337-
collection_id: str = Path(description="Collection ID")
338-
user: str = Query(...)
338+
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
339+
user: Annotated[str, Query(...)] = attr.ib()
339340

340-
@dataclass
341+
@attr.s
341342
class ItemsRequest(APIRequest):
342-
collection_id: str = Path(description="Collection ID")
343-
user: str = Query(...)
343+
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
344+
user: Annotated[str, Query(...)] = attr.ib()
344345

345-
@dataclass
346+
@attr.s
346347
class ItemRequest(APIRequest):
347-
collection_id: str = Path(description="Collection ID")
348-
item_id: str = Path(description="Item ID")
349-
user: str = Query(...)
348+
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
349+
item_id: Annotated[str, Path(description="Item ID")] = attr.ib()
350+
user: Annotated[str, Query(...)] = attr.ib()
350351

351352
test_app = app.StacApi(
352353
settings=ApiSettings(),

stac_fastapi/api/tests/test_models.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
import json
22

33
import pytest
4-
from fastapi import Depends, FastAPI
4+
from fastapi import Depends, FastAPI, HTTPException
55
from fastapi.testclient import TestClient
66
from pydantic import ValidationError
77

88
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
9-
from stac_fastapi.extensions.core.filter.filter import FilterExtension
10-
from stac_fastapi.extensions.core.sort.sort import SortExtension
9+
from stac_fastapi.extensions.core import FieldsExtension, FilterExtension, SortExtension
1110
from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest
1211

1312

1413
def test_create_get_request_model():
15-
extensions = [FilterExtension()]
16-
request_model = create_get_request_model(extensions, BaseSearchGetRequest)
14+
request_model = create_get_request_model(
15+
extensions=[FilterExtension(), FieldsExtension()],
16+
base_model=BaseSearchGetRequest,
17+
)
1718

1819
model = request_model(
1920
collections="test1,test2",
@@ -35,6 +36,9 @@ def test_create_get_request_model():
3536
assert model.collections == ["test1", "test2"]
3637
assert model.filter_crs == "epsg:4326"
3738

39+
with pytest.raises(HTTPException):
40+
request_model(datetime="yo")
41+
3842
app = FastAPI()
3943

4044
@app.get("/test")
@@ -62,8 +66,10 @@ def route(model=Depends(request_model)):
6266
[(None, True), ({"test": "test"}, True), ("test==test", False), ([], False)],
6367
)
6468
def test_create_post_request_model(filter, passes):
65-
extensions = [FilterExtension()]
66-
request_model = create_post_request_model(extensions, BaseSearchPostRequest)
69+
request_model = create_post_request_model(
70+
extensions=[FilterExtension(), FieldsExtension()],
71+
base_model=BaseSearchPostRequest,
72+
)
6773

6874
if not passes:
6975
with pytest.raises(ValidationError):
@@ -100,8 +106,10 @@ def test_create_post_request_model(filter, passes):
100106
],
101107
)
102108
def test_create_post_request_model_nested_fields(sortby, passes):
103-
extensions = [SortExtension()]
104-
request_model = create_post_request_model(extensions, BaseSearchPostRequest)
109+
request_model = create_post_request_model(
110+
extensions=[SortExtension()],
111+
base_model=BaseSearchPostRequest,
112+
)
105113

106114
if not passes:
107115
with pytest.raises(ValidationError):

stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Request model for the Aggregation extension."""
22

3-
from dataclasses import dataclass
43
from typing import List, Optional
54

5+
import attr
66
from fastapi import Query
77
from pydantic import Field
88
from typing_extensions import Annotated
@@ -14,17 +14,13 @@
1414
)
1515

1616

17-
@dataclass
17+
@attr.s
1818
class AggregationExtensionGetRequest(BaseSearchGetRequest):
1919
"""Aggregation Extension GET request model."""
2020

21-
aggregations: Annotated[Optional[str], Query()] = None
22-
23-
def __post_init__(self):
24-
"""convert attributes."""
25-
super().__post_init__()
26-
if self.aggregations:
27-
self.aggregations = str2list(self.aggregations) # type: ignore
21+
aggregations: Annotated[Optional[str], Query()] = attr.ib(
22+
default=None, converter=str2list
23+
)
2824

2925

3026
class AggregationExtensionPostRequest(BaseSearchPostRequest):

stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Request models for the fields extension."""
22

33
import warnings
4-
from dataclasses import dataclass
54
from typing import Dict, Optional, Set
65

6+
import attr
77
from fastapi import Query
88
from pydantic import BaseModel, Field
99
from typing_extensions import Annotated
@@ -70,16 +70,11 @@ def filter_fields(self) -> Dict:
7070
}
7171

7272

73-
@dataclass
73+
@attr.s
7474
class FieldsExtensionGetRequest(APIRequest):
7575
"""Additional fields for the GET request."""
7676

77-
fields: Annotated[Optional[str], Query()] = None
78-
79-
def __post_init__(self):
80-
"""convert attributes."""
81-
if self.fields:
82-
self.fields = str2list(self.fields) # type: ignore
77+
fields: Annotated[Optional[str], Query()] = attr.ib(default=None, converter=str2list)
8378

8479

8580
class FieldsExtensionPostRequest(BaseModel):

0 commit comments

Comments
 (0)