1717from sqlalchemy .orm import DeclarativeBase
1818from structlog import get_logger
1919
20+ logger = get_logger (__name__ )
21+
22+ try :
23+ from sqlmodel .ext .asyncio .session import AsyncSession
24+
25+ except ImportError :
26+ pass
27+
28+
2029__all__ = [
2130 "Base" ,
2231 "Collection" ,
3039 "open_session" ,
3140]
3241
33- SessionFactory = async_sessionmaker (expire_on_commit = False )
42+ SessionFactory = async_sessionmaker (expire_on_commit = False , class_ = AsyncSession )
3443
3544logger = get_logger (__name__ )
3645
@@ -56,6 +65,10 @@ class Hero(Base):
5665
5766 * [ORM Quick Start](https://docs.sqlalchemy.org/en/20/orm/quickstart.html)
5867 * [Declarative Mapping](https://docs.sqlalchemy.org/en/20/orm/mapping_styles.html#declarative-mapping)
68+
69+ !!! note
70+
71+ You don't need this if you use [`SQLModel`](http://sqlmodel.tiangolo.com/).
5972 """
6073
6174 __abstract__ = True
@@ -142,7 +155,7 @@ async def lifespan(app:FastAPI) -> AsyncGenerator[dict, None]:
142155
143156@asynccontextmanager
144157async def open_session () -> AsyncGenerator [AsyncSession , None ]:
145- """An asynchronous context manager that opens a new `SQLAlchemy` async session.
158+ """Async context manager that opens a new `SQLAlchemy` or `SQLModel ` async session.
146159
147160 To the contrary of the [`Session`][fastsqla.Session] dependency which can only be
148161 used in endpoints, `open_session` can be used anywhere such as in background tasks.
@@ -152,6 +165,16 @@ async def open_session() -> AsyncGenerator[AsyncSession, None]:
152165 In all cases, it closes the session and returns the associated connection to the
153166 connection pool.
154167
168+
169+ Returns:
170+ When `SQLModel` is not installed, an async generator that yields an
171+ [`SQLAlchemy AsyncSession`][sqlalchemy.ext.asyncio.AsyncSession].
172+
173+ When `SQLModel` is installed, an async generator that yields an
174+ [`SQLModel AsyncSession`](https://github.com/fastapi/sqlmodel/blob/main/sqlmodel/ext/asyncio/session.py#L32)
175+ which inherits from [`SQLAlchemy AsyncSession`][sqlalchemy.ext.asyncio.AsyncSession].
176+
177+
155178 ```python
156179 from fastsqla import open_session
157180
@@ -191,12 +214,12 @@ async def new_session() -> AsyncGenerator[AsyncSession, None]:
191214
192215
193216Session = Annotated [AsyncSession , Depends (new_session )]
194- """A dependency used exclusively in endpoints to get an `SQLAlchemy` session.
217+ """Dependency used exclusively in endpoints to get an `SQLAlchemy` or `SQLModel ` session.
195218
196219`Session` is a [`FastAPI` dependency](https://fastapi.tiangolo.com/tutorial/dependencies/)
197- that provides an asynchronous `SQLAlchemy` session.
220+ that provides an asynchronous `SQLAlchemy` session or `SQLModel` one if it's installed .
198221By defining an argument with type `Session` in an endpoint, `FastAPI` will automatically
199- inject an `SQLAlchemy` async session into the endpoint.
222+ inject an async session into the endpoint.
200223
201224At the end of request handling:
202225
@@ -336,9 +359,9 @@ async def paginate(stmt: Select) -> Page:
336359Paginate = Annotated [PaginateType [T ], Depends (new_pagination ())]
337360"""A dependency used in endpoints to paginate `SQLAlchemy` select queries.
338361
339- It adds `offset`and `limit` query parameters to the endpoint, which are used to paginate.
340- The model returned by the endpoint is a `Page` model. It contains a page of data and
341- metadata:
362+ It adds ** `offset`** and ** `limit`** query parameters to the endpoint, which are used to
363+ paginate. The model returned by the endpoint is a `Page` model. It contains a page of
364+ data and metadata:
342365
343366```json
344367{
@@ -351,55 +374,4 @@ async def paginate(stmt: Select) -> Page:
351374 }
352375}
353376```
354-
355- -----
356-
357- Example:
358- ``` py title="example.py" hl_lines="22 23 25"
359- from fastsqla import Base, Paginate, Page
360- from pydantic import BaseModel
361-
362-
363- class Hero(Base):
364- __tablename__ = "hero"
365-
366-
367- class Hero(Base):
368- __tablename__ = "hero"
369- id: Mapped[int] = mapped_column(primary_key=True)
370- name: Mapped[str] = mapped_column(unique=True)
371- secret_identity: Mapped[str]
372- age: Mapped[int]
373-
374-
375- class HeroModel(HeroBase):
376- model_config = ConfigDict(from_attributes=True)
377- id: int
378-
379-
380- @app.get("/heros", response_model=Page[HeroModel]) # (1)!
381- async def list_heros(paginate: Paginate): # (2)!
382- stmt = select(Hero)
383- return await paginate(stmt) # (3)!
384- ```
385-
386- 1. The endpoint returns a `Page` model of `HeroModel`.
387- 2. Just define an argument with type `Paginate` to get an async `paginate` function
388- injected in your endpoint function.
389- 3. Await the `paginate` function with the `SQLAlchemy` select statement to get the
390- paginated result.
391-
392- To add filtering, just add whatever query parameters you need to the endpoint:
393-
394- ```python
395-
396- from fastsqla import Paginate, Page
397-
398- @app.get("/heros", response_model=Page[HeroModel])
399- async def list_heros(paginate: Paginate, age:int | None = None):
400- stmt = select(Hero)
401- if age:
402- stmt = stmt.where(Hero.age == age)
403- return await paginate(stmt)
404- ```
405377"""
0 commit comments