diff --git a/.github/tox-uv/action.yml b/.github/tox-uv/action.yml
new file mode 100644
index 0000000..c356cb8
--- /dev/null
+++ b/.github/tox-uv/action.yml
@@ -0,0 +1,10 @@
+name: Setup tox-uv
+description: Setup tox-uv tool so tox uses uv to install dependencies
+runs:
+ using: composite
+ steps:
+ - name: โก๏ธ setup uv
+ uses: ./.github/uv
+ - name: โ๏ธ install tox-uv
+ shell: bash
+ run: uv tool install tox --with tox-uv
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0f8ebf0..700c14e 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -6,16 +6,21 @@ on:
jobs:
Tests:
+ strategy:
+ matrix:
+ tox_env: [default, sqlmodel]
runs-on: ubuntu-latest
steps:
- name: ๐ฅ checkout
uses: actions/checkout@v4
- - name: ๐ง setup uv
- uses: ./.github/uv
- - name: ๐งช pytest
- run: uv run pytest --cov fastsqla --cov-report=term-missing --cov-report=xml
+ - name: ๐ง setup tox-uv
+ uses: ./.github/tox-uv
+ - name: ๐งช tox -e ${{ matrix.tox_env }}
+ run: uv run tox -e ${{ matrix.tox_env }}
- name: "๐ codecov: upload test coverage"
uses: codecov/codecov-action@v4.2.0
+ with:
+ flags: ${{ matrix.tox_env }}
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
diff --git a/README.md b/README.md
index c41dfce..cdc1f7c 100644
--- a/README.md
+++ b/README.md
@@ -15,8 +15,10 @@ _Async SQLAlchemy 2.0+ for FastAPI โ boilerplate, pagination, and seamless ses
-----------------------------------------------------------------------------------------
-`FastSQLA` is an [`SQLAlchemy 2.0+`](https://docs.sqlalchemy.org/en/20/) extension for
-[`FastAPI`](https://fastapi.tiangolo.com/).
+`FastSQLA` is an async [`SQLAlchemy 2.0+`](https://docs.sqlalchemy.org/en/20/)
+extension for [`FastAPI`](https://fastapi.tiangolo.com/) with built-in pagination,
+[`SQLModel`](http://sqlmodel.tiangolo.com/) support and more.
+
It streamlines the configuration and asynchronous connection to relational databases by
providing boilerplate and intuitive helpers. Additionally, it offers built-in
customizable pagination and automatically manages the `SQLAlchemy` session lifecycle
@@ -74,13 +76,13 @@ following [`SQLAlchemy`'s best practices](https://docs.sqlalchemy.org/en/20/orm/
async def get_heros(paginate:Paginate):
return await paginate(select(Hero))
```
-
+
-
+
๐ `/heros?offset=10&limit=10` ๐
-
+
-
+
```json
{
"data": [
@@ -119,6 +121,32 @@ following [`SQLAlchemy`'s best practices](https://docs.sqlalchemy.org/en/20/orm/
* Session lifecycle management: session is commited on request success or rollback on
failure.
+* [`SQLModel`](http://sqlmodel.tiangolo.com/) support:
+ ```python
+ ...
+ from fastsqla import Item, Page, Paginate, Session
+ from sqlmodel import Field, SQLModel
+ ...
+
+ class Hero(SQLModel, table=True):
+ id: int | None = Field(default=None, primary_key=True)
+ name: str
+ secret_identity: str
+ age: int
+
+
+ @app.get("/heroes", response_model=Page[Hero])
+ async def get_heroes(paginate: Paginate):
+ return await paginate(select(Hero))
+
+
+ @app.get("/heroes/{hero_id}", response_model=Item[Hero])
+ async def get_hero(session: Session, hero_id: int):
+ hero = await session.get(Hero, hero_id)
+ if hero is None:
+ raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
+ return {"data": hero}
+ ```
## Installing
diff --git a/docs/orm.md b/docs/orm.md
index f4c9e4f..f508067 100644
--- a/docs/orm.md
+++ b/docs/orm.md
@@ -5,5 +5,5 @@
::: fastsqla.Base
options:
heading_level: false
- show_source: false
- show_bases: false
+ show_source: true
+ show_bases: true
diff --git a/docs/pagination.md b/docs/pagination.md
index 38fcf45..7b43482 100644
--- a/docs/pagination.md
+++ b/docs/pagination.md
@@ -6,3 +6,73 @@
options:
heading_level: false
show_source: false
+
+### `SQLAlchemy` example
+
+``` py title="example.py" hl_lines="25 26 27"
+from fastapi import FastAPI
+from fastsqla import Base, Paginate, Page, lifespan
+from pydantic import BaseModel
+from sqlalchemy import select
+from sqlalchemy.orm import Mapped, mapped_column
+
+app = FastAPI(lifespan=lifespan)
+
+class Hero(Base):
+ __tablename__ = "hero"
+ id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str] = mapped_column(unique=True)
+ secret_identity: Mapped[str]
+ age: Mapped[int]
+
+
+class HeroModel(HeroBase):
+ model_config = ConfigDict(from_attributes=True)
+ id: int
+ name: str
+ secret_identity: str
+ age: int
+
+
+@app.get("/heros", response_model=Page[HeroModel]) # (1)!
+async def list_heros(paginate: Paginate): # (2)!
+ return await paginate(select(Hero)) # (3)!
+```
+
+1. The endpoint returns a `Page` model of `HeroModel`.
+2. Just define an argument with type `Paginate` to get an async `paginate` function
+ injected in your endpoint function.
+3. Await the `paginate` function with the `SQLAlchemy` select statement to get the
+ paginated result.
+
+To add filtering, just add whatever query parameters you need to the endpoint:
+
+```python
+@app.get("/heros", response_model=Page[HeroModel])
+async def list_heros(paginate: Paginate, age:int | None = None):
+ stmt = select(Hero)
+ if age:
+ stmt = stmt.where(Hero.age == age)
+ return await paginate(stmt)
+```
+
+### `SQLModel` example
+
+```python
+from fastapi import FastAPI
+from fastsqla import Page, Paginate, Session
+from sqlmodel import Field, SQLModel
+from sqlalchemy import select
+
+
+class Hero(SQLModel, table=True):
+ id: int | None = Field(default=None, primary_key=True)
+ name: str
+ secret_identity: str
+ age: int
+
+
+@app.get("/heroes", response_model=Page[Hero])
+async def get_heroes(paginate: Paginate):
+ return await paginate(select(Hero))
+```
diff --git a/pyproject.toml b/pyproject.toml
index cb07363..8477397 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,6 +48,7 @@ docs = [
"mkdocs-material>=9.5.50",
"mkdocstrings[python]>=0.27.0",
]
+sqlmodel = ["sqlmodel>=0.0.22"]
[tool.uv]
package = true
@@ -71,7 +72,10 @@ dev-dependencies = [
pytest-watch = { git = "https://github.com/styleseat/pytest-watch", rev = "0342193" }
[tool.pytest.ini_options]
-asyncio_mode = 'auto'
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "function"
+
+filterwarnings = ["ignore::DeprecationWarning:"]
[tool.coverage.run]
branch = true
@@ -86,3 +90,17 @@ version_toml = ["pyproject.toml:project.version"]
[tool.semantic_release.changelog.default_templates]
changelog_file = "./docs/changelog.md"
+
+[tool.tox]
+legacy_tox_ini = """
+[tox]
+envlist = { default, sqlmodel }
+
+[testenv]
+passenv = CI
+runner = uv-venv-lock-runner
+commands =
+ pytest --cov fastsqla --cov-report=term-missing --cov-report=xml
+extras:
+ sqlmodel: sqlmodel
+"""
diff --git a/src/fastsqla.py b/src/fastsqla.py
index 636afd4..1a42e37 100644
--- a/src/fastsqla.py
+++ b/src/fastsqla.py
@@ -17,6 +17,15 @@
from sqlalchemy.orm import DeclarativeBase
from structlog import get_logger
+logger = get_logger(__name__)
+
+try:
+ from sqlmodel.ext.asyncio.session import AsyncSession
+
+except ImportError:
+ pass
+
+
__all__ = [
"Base",
"Collection",
@@ -30,7 +39,7 @@
"open_session",
]
-SessionFactory = async_sessionmaker(expire_on_commit=False)
+SessionFactory = async_sessionmaker(expire_on_commit=False, class_=AsyncSession)
logger = get_logger(__name__)
@@ -56,6 +65,10 @@ class Hero(Base):
* [ORM Quick Start](https://docs.sqlalchemy.org/en/20/orm/quickstart.html)
* [Declarative Mapping](https://docs.sqlalchemy.org/en/20/orm/mapping_styles.html#declarative-mapping)
+
+ !!! note
+
+ You don't need this if you use [`SQLModel`](http://sqlmodel.tiangolo.com/).
"""
__abstract__ = True
@@ -142,7 +155,7 @@ async def lifespan(app:FastAPI) -> AsyncGenerator[dict, None]:
@asynccontextmanager
async def open_session() -> AsyncGenerator[AsyncSession, None]:
- """An asynchronous context manager that opens a new `SQLAlchemy` async session.
+ """Async context manager that opens a new `SQLAlchemy` or `SQLModel` async session.
To the contrary of the [`Session`][fastsqla.Session] dependency which can only be
used in endpoints, `open_session` can be used anywhere such as in background tasks.
@@ -152,6 +165,16 @@ async def open_session() -> AsyncGenerator[AsyncSession, None]:
In all cases, it closes the session and returns the associated connection to the
connection pool.
+
+ Returns:
+ When `SQLModel` is not installed, an async generator that yields an
+ [`SQLAlchemy AsyncSession`][sqlalchemy.ext.asyncio.AsyncSession].
+
+ When `SQLModel` is installed, an async generator that yields an
+ [`SQLModel AsyncSession`](https://github.com/fastapi/sqlmodel/blob/main/sqlmodel/ext/asyncio/session.py#L32)
+ which inherits from [`SQLAlchemy AsyncSession`][sqlalchemy.ext.asyncio.AsyncSession].
+
+
```python
from fastsqla import open_session
@@ -191,12 +214,12 @@ async def new_session() -> AsyncGenerator[AsyncSession, None]:
Session = Annotated[AsyncSession, Depends(new_session)]
-"""A dependency used exclusively in endpoints to get an `SQLAlchemy` session.
+"""Dependency used exclusively in endpoints to get an `SQLAlchemy` or `SQLModel` session.
`Session` is a [`FastAPI` dependency](https://fastapi.tiangolo.com/tutorial/dependencies/)
-that provides an asynchronous `SQLAlchemy` session.
+that provides an asynchronous `SQLAlchemy` session or `SQLModel` one if it's installed.
By defining an argument with type `Session` in an endpoint, `FastAPI` will automatically
-inject an `SQLAlchemy` async session into the endpoint.
+inject an async session into the endpoint.
At the end of request handling:
@@ -336,9 +359,9 @@ async def paginate(stmt: Select) -> Page:
Paginate = Annotated[PaginateType[T], Depends(new_pagination())]
"""A dependency used in endpoints to paginate `SQLAlchemy` select queries.
-It adds `offset`and `limit` query parameters to the endpoint, which are used to paginate.
-The model returned by the endpoint is a `Page` model. It contains a page of data and
-metadata:
+It adds **`offset`** and **`limit`** query parameters to the endpoint, which are used to
+paginate. The model returned by the endpoint is a `Page` model. It contains a page of
+data and metadata:
```json
{
@@ -351,55 +374,4 @@ async def paginate(stmt: Select) -> Page:
}
}
```
-
------
-
-Example:
-``` py title="example.py" hl_lines="22 23 25"
-from fastsqla import Base, Paginate, Page
-from pydantic import BaseModel
-
-
-class Hero(Base):
- __tablename__ = "hero"
-
-
-class Hero(Base):
- __tablename__ = "hero"
- id: Mapped[int] = mapped_column(primary_key=True)
- name: Mapped[str] = mapped_column(unique=True)
- secret_identity: Mapped[str]
- age: Mapped[int]
-
-
-class HeroModel(HeroBase):
- model_config = ConfigDict(from_attributes=True)
- id: int
-
-
-@app.get("/heros", response_model=Page[HeroModel]) # (1)!
-async def list_heros(paginate: Paginate): # (2)!
- stmt = select(Hero)
- return await paginate(stmt) # (3)!
-```
-
-1. The endpoint returns a `Page` model of `HeroModel`.
-2. Just define an argument with type `Paginate` to get an async `paginate` function
- injected in your endpoint function.
-3. Await the `paginate` function with the `SQLAlchemy` select statement to get the
- paginated result.
-
-To add filtering, just add whatever query parameters you need to the endpoint:
-
-```python
-
-from fastsqla import Paginate, Page
-
-@app.get("/heros", response_model=Page[HeroModel])
-async def list_heros(paginate: Paginate, age:int | None = None):
- stmt = select(Hero)
- if age:
- stmt = stmt.where(Hero.age == age)
- return await paginate(stmt)
-```
"""
diff --git a/tests/conftest.py b/tests/conftest.py
index 1d6335b..0daebf9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,9 +1,15 @@
from unittest.mock import patch
-from pytest import fixture
+from pytest import fixture, skip
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
+def pytest_configure(config):
+ config.addinivalue_line(
+ "markers", "require_sqlmodel: skip test when sqlmodel is not installed."
+ )
+
+
@fixture
def environ(tmp_path):
values = {
@@ -38,3 +44,19 @@ def tear_down():
Base.metadata.clear()
clear_mappers()
+
+
+try:
+ import sqlmodel # noqa
+except ImportError:
+ is_sqlmodel_installed = False
+else:
+ is_sqlmodel_installed = True
+
+
+@fixture(autouse=True)
+def check_sqlmodel(request):
+ """Skip test marked with mark.require_sqlmodel if sqlmodel is not installed."""
+ marker = request.node.get_closest_marker("require_sqlmodel")
+ if marker and not is_sqlmodel_installed:
+ skip(f"{request.node.nodeid} requires sqlmodel which is not installed.")
diff --git a/tests/integration/test_sqlmodel.py b/tests/integration/test_sqlmodel.py
new file mode 100644
index 0000000..7fb8c30
--- /dev/null
+++ b/tests/integration/test_sqlmodel.py
@@ -0,0 +1,151 @@
+from http import HTTPStatus
+
+from fastapi import HTTPException
+from pytest import fixture, mark
+from sqlalchemy import insert, select, text
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.automap import automap_base
+
+
+pytestmark = mark.require_sqlmodel
+
+
+@fixture
+def heros_data():
+ return [
+ ("Superman", "Clark Kent", 30),
+ ("Batman", "Bruce Wayne", 35),
+ ("Wonder Woman", "Diana Prince", 30),
+ ("Iron Man", "Tony Stark", 45),
+ ("Spider-Man", "Peter Parker", 25),
+ ("Captain America", "Steve Rogers", 100),
+ ("Black Widow", "Natasha Romanoff", 35),
+ ("Thor", "Thor Odinson", 1500),
+ ("Scarlet Witch", "Wanda Maximoff", 30),
+ ("Doctor Strange", "Stephen Strange", 40),
+ ("The Flash", "Barry Allen", 28),
+ ("Green Lantern", "Hal Jordan", 35),
+ ]
+
+
+@fixture(autouse=True)
+async def setup_tear_down(engine, heros_data):
+ Base = automap_base()
+ async with engine.connect() as conn:
+ await conn.execute(
+ text("""
+ CREATE TABLE hero (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT NOT NULL UNIQUE,
+ secret_identity TEXT NOT NULL,
+ age INTEGER NOT NULL
+ )
+ """)
+ )
+
+ await conn.run_sync(Base.prepare)
+
+ Hero = Base.classes.hero
+
+ stmt = insert(Hero).values(
+ [
+ dict(name=name, secret_identity=secret_identity, age=age)
+ for name, secret_identity, age in heros_data
+ ]
+ )
+ await conn.execute(stmt)
+ await conn.commit()
+ yield
+ await conn.execute(text("DROP TABLE hero"))
+
+
+@fixture
+async def app(setup_tear_down, app):
+ from fastsqla import Item, Page, Paginate, Session
+ from sqlmodel import Field, SQLModel
+
+ class Hero(SQLModel, table=True):
+ __table_args__ = {"extend_existing": True}
+ id: int | None = Field(default=None, primary_key=True)
+ name: str
+ secret_identity: str
+ age: int
+
+ @app.get("/heroes", response_model=Page[Hero])
+ async def get_heroes(paginate: Paginate):
+ return await paginate(select(Hero))
+
+ @app.get("/heroes/{hero_id}", response_model=Item[Hero])
+ async def get_hero(session: Session, hero_id: int):
+ hero = await session.get(Hero, hero_id)
+ if hero is None:
+ raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
+ return {"data": hero}
+
+ @app.post("/heroes", response_model=Item[Hero])
+ async def create_hero(session: Session, hero: Hero):
+ session.add(hero)
+ try:
+ await session.flush()
+ except IntegrityError:
+ raise HTTPException(status_code=HTTPStatus.CONFLICT)
+ return {"data": hero}
+
+ return app
+
+
+@mark.parametrize("offset, page_number, items_count", [[0, 1, 10], [10, 2, 2]])
+async def test_pagination(client, heros_data, offset, page_number, items_count):
+ res = await client.get("/heroes", params={"offset": offset})
+ assert res.status_code == 200, (res.status_code, res.content)
+
+ payload = res.json()
+ assert "data" in payload
+ data = payload["data"]
+ assert len(data) == items_count
+
+ for i, hero in enumerate(data):
+ name, secret_identity, age = heros_data[i + offset]
+ assert hero["id"]
+ assert hero["name"] == name
+ assert hero["secret_identity"] == secret_identity
+ assert hero["age"] == age
+
+ assert "meta" in payload
+ assert payload["meta"]["total_items"] == 12
+ assert payload["meta"]["total_pages"] == 2
+ assert payload["meta"]["offset"] == offset
+ assert payload["meta"]["page_number"] == page_number
+
+
+async def test_getting_an_entity_with_session_dependency(client, heros_data):
+ res = await client.get("/heroes/1")
+ assert res.status_code == 200, (res.status_code, res.content)
+
+ payload = res.json()
+ assert "data" in payload
+ data = payload["data"]
+
+ name, secret_identity, age = heros_data[0]
+ assert data["id"] == 1
+ assert data["name"] == name
+ assert data["secret_identity"] == secret_identity
+ assert data["age"] == age
+
+
+async def test_creating_an_entity_with_session_dependency(client):
+ hero = {"name": "Hulk", "secret_identity": "Bruce Banner", "age": 37}
+ res = await client.post("/heroes", json=hero)
+ assert res.status_code == 200, (res.status_code, res.content)
+
+ data = res.json()["data"]
+ assert data["id"] == 13
+ assert data["name"] == hero["name"]
+ assert data["secret_identity"] == hero["secret_identity"]
+ assert data["age"] == hero["age"]
+
+
+async def test_creating_an_entity_with_conflict(client):
+ hero = {"name": "Superman", "secret_identity": "Clark Kent", "age": 30}
+ res = await client.post("/heroes", json=hero)
+ assert res.status_code == HTTPStatus.CONFLICT, (res.status_code, res.content)
diff --git a/uv.lock b/uv.lock
index 3976e36..3220be3 100644
--- a/uv.lock
+++ b/uv.lock
@@ -123,7 +123,7 @@ name = "click"
version = "8.1.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "colorama", marker = "platform_system == 'Windows'" },
+ { name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
wheels = [
@@ -278,6 +278,9 @@ docs = [
{ name = "mkdocs-material" },
{ name = "mkdocstrings", extra = ["python"] },
]
+sqlmodel = [
+ { name = "sqlmodel" },
+]
[package.dev-dependencies]
dev = [
@@ -303,6 +306,7 @@ requires-dist = [
{ name = "mkdocs-material", marker = "extra == 'docs'", specifier = ">=9.5.50" },
{ name = "mkdocstrings", extras = ["python"], marker = "extra == 'docs'", specifier = ">=0.27.0" },
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.37" },
+ { name = "sqlmodel", marker = "extra == 'sqlmodel'", specifier = ">=0.0.22" },
{ name = "structlog", specifier = ">=24.4.0" },
]
@@ -615,7 +619,7 @@ version = "1.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
- { name = "colorama", marker = "platform_system == 'Windows'" },
+ { name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "ghp-import" },
{ name = "jinja2" },
{ name = "markdown" },
@@ -1261,6 +1265,19 @@ asyncio = [
{ name = "greenlet" },
]
+[[package]]
+name = "sqlmodel"
+version = "0.0.22"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pydantic" },
+ { name = "sqlalchemy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b5/39/8641040ab0d5e1d8a1c2325ae89a01ae659fc96c61a43d158fb71c9a0bf0/sqlmodel-0.0.22.tar.gz", hash = "sha256:7d37c882a30c43464d143e35e9ecaf945d88035e20117bf5ec2834a23cbe505e", size = 116392 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/dd/b1/3af5104b716c420e40a6ea1b09886cae3a1b9f4538343875f637755cae5b/sqlmodel-0.0.22-py3-none-any.whl", hash = "sha256:a1ed13e28a1f4057cbf4ff6cdb4fc09e85702621d3259ba17b3c230bfb2f941b", size = 28276 },
+]
+
[[package]]
name = "starlette"
version = "0.41.3"