Skip to content

Commit cf2984e

Browse files
authored
Merge pull request #35 from igorbenav/34-dont-return-arrays-as-top-level-responses
34 dont return arrays as top level responses
2 parents 4ffd63c + b6593fb commit cf2984e

File tree

12 files changed

+223
-115
lines changed

12 files changed

+223
-115
lines changed

README.md

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
- 👜 Easy client-side caching
4949
- 🚦 ARQ integration for task queue
5050
- ⚙️ Efficient querying (only queries what's needed)
51+
- ⎘ Out of the box pagination support
5152
- 👮 FastAPI docs behind authentication and hidden based on the environment
5253
- 🦾 Easily extendable
5354
- 🤸‍♂️ Flexible
@@ -65,7 +66,6 @@
6566
#### Features
6667
- [ ] Add a Rate Limiter decorator
6768
- [ ] Add mongoDB support
68-
- [x] Add support in schema_to_select for a list of column names
6969

7070
#### Tests
7171
- [ ] Add Ruff linter
@@ -569,20 +569,48 @@ crud_users = CRUDUser(User)
569569

570570
When actually using the crud in an endpoint, to get data you just pass the database connection and the attributes as kwargs:
571571
```python
572-
# Here I'm getting the users with email == user.email
572+
# Here I'm getting the first user with email == user.email (email is unique in this case)
573573
user = await crud_users.get(db=db, email=user.email)
574574
```
575575

576576
To get a list of objects with the attributes, you should use the get_multi:
577577
```python
578-
# Here I'm getting 100 users with the name David except for the first 3
578+
# Here I'm getting at most 10 users with the name 'User Userson' except for the first 3
579579
user = await crud_users.get_multi(
580580
db=db,
581581
offset=3,
582582
limit=100,
583-
name="David"
583+
name="User Userson"
584584
)
585585
```
586+
> **Warning**
587+
> Note that get_multi returns a python `dict`.
588+
589+
Which will return a python dict with the following structure:
590+
```javascript
591+
{
592+
"data": [
593+
{
594+
"id": 4,
595+
"name": "User Userson",
596+
"username": "userson4",
597+
"email": "[email protected]",
598+
"profile_image_url": "https://profileimageurl.com"
599+
},
600+
{
601+
"id": 5,
602+
"name": "User Userson",
603+
"username": "userson5",
604+
"email": "[email protected]",
605+
"profile_image_url": "https://profileimageurl.com"
606+
}
607+
],
608+
"total_count": 2,
609+
"has_more": false,
610+
"page": 1,
611+
"items_per_page": 10
612+
}
613+
```
586614

587615
To create, you pass a `CreateSchemaType` object with the attributes, such as a `UserCreate` pydantic schema:
588616
```python
@@ -606,6 +634,15 @@ To just check if there is at least one row that matches a certain set of attribu
606634
crud_users.exists(db=db, email=user@example.com)
607635
```
608636

637+
You can also get the count of a certain object with the specified filter:
638+
```python
639+
# Here I'm getting the count of users with the name 'User Userson'
640+
user = await crud_users.count(
641+
db=db,
642+
name="User Userson"
643+
)
644+
```
645+
609646
To update you pass an `object` which may be a `pydantic schema` or just a regular `dict`, and the kwargs.
610647
You will update with `objects` the rows that match your `kwargs`.
611648
```python
@@ -696,7 +733,7 @@ async def sample_endpoint(request: Request, my_id: int):
696733

697734
The way it works is:
698735
- the data is saved in redis with the following cache key: `sample_data:{my_id}`
699-
- then the the time to expire is set as 3600 seconds (that's the default)
736+
- then the time to expire is set as 3600 seconds (that's the default)
700737

701738
Another option is not passing the `resource_id_name`, but passing the `resource_id_type` (default int):
702739
```python

docker-compose.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ services:
3737
- ./src/.env
3838
volumes:
3939
- postgres-data:/var/lib/postgresql/data
40+
ports:
41+
- "5432:5432"
4042

4143
redis:
4244
image: redis:alpine

src/app/api/v1/posts.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from app.crud.crud_users import crud_users
1313
from app.api.exceptions import privileges_exception
1414
from app.core.cache import cache
15+
from app.core.models import PaginatedListResponse
1516

1617
router = fastapi.APIRouter(tags=["posts"])
1718

@@ -27,29 +28,45 @@ async def write_post(
2728
if db_user is None:
2829
raise HTTPException(status_code=404, detail="User not found")
2930

30-
if current_user.id != db_user.id:
31+
if current_user.id != db_user["id"]:
3132
raise privileges_exception
3233

3334
post_internal_dict = post.model_dump()
34-
post_internal_dict["created_by_user_id"] = db_user.id
35+
post_internal_dict["created_by_user_id"] = db_user["id"]
3536

3637
post_internal = PostCreateInternal(**post_internal_dict)
3738
return await crud_posts.create(db=db, object=post_internal)
3839

3940

40-
@router.get("/{username}/posts", response_model=List[PostRead])
41+
@router.get("/{username}/posts", response_model=PaginatedListResponse[PostRead])
4142
@cache(key_prefix="{username}_posts", resource_id_name="username")
4243
async def read_posts(
4344
request: Request,
44-
username: str,
45-
db: Annotated[AsyncSession, Depends(async_get_db)]
45+
username: str,
46+
db: Annotated[AsyncSession, Depends(async_get_db)],
47+
page: int = 1,
48+
items_per_page: int = 10
4649
):
4750
db_user = await crud_users.get(db=db, schema_to_select=UserRead, username=username, is_deleted=False)
48-
if db_user is None:
51+
if not db_user:
4952
raise HTTPException(status_code=404, detail="User not found")
50-
51-
posts = await crud_posts.get_multi(db=db, schema_to_select=PostRead, created_by_user_id=db_user.id, is_deleted=False)
52-
return posts
53+
54+
posts_data = await crud_posts.get_multi(
55+
db=db,
56+
offset=(page - 1) * items_per_page,
57+
limit=items_per_page,
58+
schema_to_select=PostRead,
59+
created_by_user_id=db_user["id"],
60+
is_deleted=False
61+
)
62+
63+
return {
64+
"data": posts_data["data"],
65+
"total_count": posts_data["total_count"],
66+
"has_more": (page * items_per_page) < posts_data["total_count"],
67+
"page": page,
68+
"items_per_page": items_per_page
69+
}
5370

5471

5572
@router.get("/{username}/post/{id}", response_model=PostRead)
@@ -64,10 +81,10 @@ async def read_post(
6481
if db_user is None:
6582
raise HTTPException(status_code=404, detail="User not found")
6683

67-
db_post = await crud_posts.get(db=db, schema_to_select=PostRead, id=id, created_by_user_id=db_user.id, is_deleted=False)
84+
db_post = await crud_posts.get(db=db, schema_to_select=PostRead, id=id, created_by_user_id=db_user["id"], is_deleted=False)
6885
if db_post is None:
6986
raise HTTPException(status_code=404, detail="Post not found")
70-
87+
7188
return db_post
7289

7390

@@ -89,7 +106,7 @@ async def patch_post(
89106
if db_user is None:
90107
raise HTTPException(status_code=404, detail="User not found")
91108

92-
if current_user.id != db_user.id:
109+
if current_user.id != db_user["id"]:
93110
raise privileges_exception
94111

95112
db_post = await crud_posts.get(db=db, schema_to_select=PostRead, id=id, is_deleted=False)

src/app/api/v1/users.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from app.core.security import get_password_hash
1212
from app.crud.crud_users import crud_users
1313
from app.api.exceptions import privileges_exception
14+
from app.core.models import PaginatedListResponse
1415

1516
router = fastapi.APIRouter(tags=["users"])
1617

@@ -36,10 +37,28 @@ async def write_user(
3637
return await crud_users.create(db=db, object=user_internal)
3738

3839

39-
@router.get("/users", response_model=List[UserRead])
40-
async def read_users(request: Request, db: Annotated[AsyncSession, Depends(async_get_db)]):
41-
users = await crud_users.get_multi(db=db, schema_to_select=UserRead, is_deleted=False)
42-
return users
40+
@router.get("/users", response_model=PaginatedListResponse[UserRead])
41+
async def read_users(
42+
request: Request,
43+
db: Annotated[AsyncSession, Depends(async_get_db)],
44+
page: int = 1,
45+
items_per_page: int = 10
46+
):
47+
users_data = await crud_users.get_multi(
48+
db=db,
49+
offset=(page - 1) * items_per_page,
50+
limit=items_per_page,
51+
schema_to_select=UserRead,
52+
is_deleted=False
53+
)
54+
55+
return {
56+
"data": users_data["data"],
57+
"total_count": users_data["total_count"],
58+
"has_more": (page * items_per_page) < users_data["total_count"],
59+
"page": page,
60+
"items_per_page": items_per_page
61+
}
4362

4463

4564
@router.get("/user/me/", response_model=UserRead)

src/app/core/cache.py

Lines changed: 10 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import re
77

88
from fastapi import Request, Response
9+
from fastapi.encoders import jsonable_encoder
910
from redis.asyncio import Redis, ConnectionPool
10-
from sqlalchemy.orm import class_mapper, DeclarativeBase
1111
from fastapi import FastAPI
1212
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
1313

@@ -18,40 +18,6 @@
1818
pool: ConnectionPool | None = None
1919
client: Redis | None = None
2020

21-
def _serialize_sqlalchemy_object(obj: DeclarativeBase) -> Dict[str, Any]:
22-
"""
23-
Serialize a SQLAlchemy DeclarativeBase object to a dictionary.
24-
25-
Parameters
26-
----------
27-
obj: DeclarativeBase
28-
The SQLAlchemy DeclarativeBase object to be serialized.
29-
30-
Returns
31-
-------
32-
Dict[str, Any]
33-
A dictionary containing the serialized attributes of the object.
34-
35-
Note
36-
----
37-
- Datetime objects are converted to ISO 8601 string format.
38-
- UUID objects are converted to strings before serializing to JSON.
39-
"""
40-
if isinstance(obj, DeclarativeBase):
41-
data = {}
42-
for column in class_mapper(obj.__class__).columns:
43-
value = getattr(obj, column.name)
44-
45-
if isinstance(value, datetime):
46-
value = value.isoformat()
47-
48-
if isinstance(value, UUID):
49-
value = str(value)
50-
51-
data[column.name] = value
52-
return data
53-
54-
5521
def _infer_resource_id(kwargs: Dict[str, Any], resource_id_type: Union[type, str]) -> Union[None, int, str]:
5622
"""
5723
Infer the resource ID from a dictionary of keyword arguments.
@@ -236,8 +202,8 @@ async def sample_endpoint(request: Request, resource_id: int):
236202
This decorator caches the response data of the endpoint function using a unique cache key.
237203
The cached data is retrieved for GET requests, and the cache is invalidated for other types of requests.
238204
239-
Note:
240-
- For caching lists of objects, ensure that the response is a list of objects, and the decorator will handle caching accordingly.
205+
Note
206+
----
241207
- resource_id_type is used only if resource_id is not passed.
242208
"""
243209
def wrapper(func: Callable) -> Callable:
@@ -250,32 +216,26 @@ async def inner(request: Request, *args, **kwargs) -> Response:
250216

251217
formatted_key_prefix = _format_prefix(key_prefix, kwargs)
252218
cache_key = f"{formatted_key_prefix}:{resource_id}"
253-
254219
if request.method == "GET":
255220
if to_invalidate_extra:
256221
raise InvalidRequestError
257222

258223
cached_data = await client.get(cache_key)
259224
if cached_data:
225+
print("cache hit")
260226
return json.loads(cached_data.decode())
261-
227+
262228
result = await func(request, *args, **kwargs)
263229

264230
if request.method == "GET":
265-
if to_invalidate_extra:
266-
raise InvalidRequestError
231+
serializable_data = jsonable_encoder(result)
232+
serialized_data = json.dumps(serializable_data)
267233

268-
if isinstance(result, list):
269-
serialized_data = json.dumps(
270-
[_serialize_sqlalchemy_object(obj) for obj in result]
271-
)
272-
else:
273-
serialized_data = json.dumps(
274-
_serialize_sqlalchemy_object(result)
275-
)
276-
277234
await client.set(cache_key, serialized_data)
278235
await client.expire(cache_key, expiration)
236+
237+
serialized_data = json.loads(serialized_data)
238+
279239
else:
280240
await client.delete(cache_key)
281241
if to_invalidate_extra:

src/app/core/models.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
1+
from typing import TypeVar, Generic, List
12
import uuid as uuid_pkg
23
from datetime import datetime
34

4-
from pydantic import BaseModel, Field
5+
from pydantic import BaseModel, Field, field_serializer
56
from sqlalchemy import text
67

7-
from app.core.database import Base
8+
ReadSchemaType = TypeVar("ReadSchemaType", bound=BaseModel)
9+
10+
class ListResponse(BaseModel, Generic[ReadSchemaType]):
11+
data: List[ReadSchemaType]
12+
13+
14+
class PaginatedListResponse(ListResponse[ReadSchemaType]):
15+
total_count: int
16+
has_more: bool
17+
page: int | None = None
18+
items_per_page: int | None = None
19+
820

921
class HealthCheck(BaseModel):
1022
name: str
@@ -47,7 +59,16 @@ class TimestampModel(BaseModel):
4759
}
4860
)
4961

50-
62+
@field_serializer("created_at")
63+
def serialize_dt(self, created_at: datetime | None, _info):
64+
return created_at.isoformat()
65+
66+
@field_serializer("updated_at")
67+
def serialize_updated_at(self, updated_at: datetime | None, _info):
68+
if updated_at is not None:
69+
return updated_at.isoformat()
70+
71+
5172
class PersistentDeletion(BaseModel):
5273
deleted_at: datetime | None = Field(
5374
default=None,
@@ -58,3 +79,8 @@ class PersistentDeletion(BaseModel):
5879
)
5980

6081
is_deleted: bool = False
82+
83+
@field_serializer('deleted_at')
84+
def serialize_dates(self, deleted_at: datetime | None, _info):
85+
if deleted_at is not None:
86+
return deleted_at.isoformat()

0 commit comments

Comments
 (0)