Skip to content

Commit

Permalink
✨ Update AsyncSession with support for SQLAlchemy 2.0 and new Session
Browse files Browse the repository at this point in the history
  • Loading branch information
tiangolo committed Nov 17, 2023
1 parent baa5e3a commit 80a5c52
Showing 1 changed file with 101 additions and 43 deletions.
144 changes: 101 additions & 43 deletions sqlmodel/ext/asyncio/session.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,49 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
from typing import (
Any,
Dict,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
Union,
cast,
overload,
)

from sqlalchemy import util
from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams
from sqlalchemy.engine.result import Result, ScalarResult, TupleResult
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio import engine
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.ext.asyncio.result import _ensure_sync_result
from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS
from sqlalchemy.orm._typing import OrmExecuteOptionsParameter
from sqlalchemy.sql.base import Executable as _Executable
from sqlalchemy.util.concurrency import greenlet_spawn
from typing_extensions import deprecated

from ...engine.result import Result, ScalarResult
from ...orm.session import Session
from ...sql.base import Executable
from ...sql.expression import Select, SelectOfScalar

_TSelectParam = TypeVar("_TSelectParam")
_TSelectParam = TypeVar("_TSelectParam", bound=Any)


class AsyncSession(_AsyncSession):
sync_session_class: Type[Session] = Session
sync_session: Session

def __init__(
self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
**kw: Any,
):
# All the same code of the original AsyncSession
kw["future"] = True
if bind:
self.bind = bind
bind = engine._get_sync_engine_or_connection(bind) # type: ignore

if binds:
self.binds = binds
binds = {
key: engine._get_sync_engine_or_connection(b) # type: ignore
for key, b in binds.items()
}

self.sync_session = self._proxied = self._assign_proxied( # type: ignore
Session(bind=bind, binds=binds, **kw) # type: ignore
)

@overload
async def exec(
self,
statement: Select[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Result[_TSelectParam]:
) -> TupleResult[_TSelectParam]:
...

@overload
Expand All @@ -61,10 +53,9 @@ async def exec(
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> ScalarResult[_TSelectParam]:
...

Expand All @@ -75,20 +66,87 @@ async def exec(
SelectOfScalar[_TSelectParam],
Executable[_TSelectParam],
],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
# TODO: the documentation says execution_options accepts a dict, but only
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore

return await greenlet_spawn(
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]:
if execution_options:
execution_options = util.immutabledict(execution_options).union(
_EXECUTE_OPTIONS
)
else:
execution_options = _EXECUTE_OPTIONS

result = await greenlet_spawn(
self.sync_session.exec,
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
)
result_value = await _ensure_sync_result(
cast(Result[_TSelectParam], result), self.exec
)
return result_value # type: ignore

@deprecated(
"""
🚨 You probably want to use `session.exec()` instead of `session.execute()`.
This is the original SQLAlchemy `session.execute()` method that returns objects
of type `Row`, and that you have to call `scalars()` to get the model objects.
For example:
```Python
heroes = await session.execute(select(Hero)).scalars().all()
```
instead you could use `exec()`:
```Python
heroes = await session.exec(select(Hero)).all()
```
"""
)
async def execute( # type: ignore
self,
statement: _Executable,
params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Result[Any]:
"""
🚨 You probably want to use `session.exec()` instead of `session.execute()`.
This is the original SQLAlchemy `session.execute()` method that returns objects
of type `Row`, and that you have to call `scalars()` to get the model objects.
For example:
```Python
heroes = await session.execute(select(Hero)).scalars().all()
```
instead you could use `exec()`:
```Python
heroes = await session.exec(select(Hero)).all()
```
"""
return await super().execute(
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
)

0 comments on commit 80a5c52

Please sign in to comment.