diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index f500c44dc2..012d8ef5e4 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -1,45 +1,38 @@ -from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload +from typing import ( + Any, + Dict, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Union, + cast, + overload, +) from sqlalchemy import util +from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams +from sqlalchemy.engine.result import Result, ScalarResult, TupleResult from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession -from sqlalchemy.ext.asyncio import engine -from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine +from sqlalchemy.ext.asyncio.result import _ensure_sync_result +from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS +from sqlalchemy.orm._typing import OrmExecuteOptionsParameter +from sqlalchemy.sql.base import Executable as _Executable from sqlalchemy.util.concurrency import greenlet_spawn +from typing_extensions import deprecated -from ...engine.result import Result, ScalarResult from ...orm.session import Session from ...sql.base import Executable from ...sql.expression import Select, SelectOfScalar -_TSelectParam = TypeVar("_TSelectParam") +_TSelectParam = TypeVar("_TSelectParam", bound=Any) class AsyncSession(_AsyncSession): + sync_session_class: Type[Session] = Session sync_session: Session - def __init__( - self, - bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, - binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, - **kw: Any, - ): - # All the same code of the original AsyncSession - kw["future"] = True - if bind: - self.bind = bind - bind = engine._get_sync_engine_or_connection(bind) # type: ignore - - if binds: - self.binds = binds - binds = { - key: engine._get_sync_engine_or_connection(b) # type: ignore - for key, b in binds.items() - } - - self.sync_session = self._proxied = self._assign_proxied( # type: ignore - Session(bind=bind, binds=binds, **kw) # type: ignore - ) - @overload async def exec( self, @@ -47,11 +40,10 @@ async def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - **kw: Any, - ) -> Result[_TSelectParam]: + ) -> TupleResult[_TSelectParam]: ... @overload @@ -61,10 +53,9 @@ async def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - **kw: Any, ) -> ScalarResult[_TSelectParam]: ... @@ -75,20 +66,87 @@ async def exec( SelectOfScalar[_TSelectParam], Executable[_TSelectParam], ], + *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Mapping[Any, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, - **kw: Any, - ) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]: - # TODO: the documentation says execution_options accepts a dict, but only - # util.immutabledict has the union() method. Is this a bug in SQLAlchemy? - execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore - - return await greenlet_spawn( + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]: + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _EXECUTE_OPTIONS + ) + else: + execution_options = _EXECUTE_OPTIONS + + result = await greenlet_spawn( self.sync_session.exec, statement, params=params, execution_options=execution_options, bind_arguments=bind_arguments, - **kw, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, + ) + result_value = await _ensure_sync_result( + cast(Result[_TSelectParam], result), self.exec + ) + return result_value # type: ignore + + @deprecated( + """ + 🚨 You probably want to use `session.exec()` instead of `session.execute()`. + + This is the original SQLAlchemy `session.execute()` method that returns objects + of type `Row`, and that you have to call `scalars()` to get the model objects. + + For example: + + ```Python + heroes = await session.execute(select(Hero)).scalars().all() + ``` + + instead you could use `exec()`: + + ```Python + heroes = await session.exec(select(Hero)).all() + ``` + """ + ) + async def execute( # type: ignore + self, + statement: _Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + """ + 🚨 You probably want to use `session.exec()` instead of `session.execute()`. + + This is the original SQLAlchemy `session.execute()` method that returns objects + of type `Row`, and that you have to call `scalars()` to get the model objects. + + For example: + + ```Python + heroes = await session.execute(select(Hero)).scalars().all() + ``` + + instead you could use `exec()`: + + ```Python + heroes = await session.exec(select(Hero)).all() + ``` + """ + return await super().execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, )