diff --git a/aiomysql/sa/connection.py b/aiomysql/sa/connection.py index d9597aef..f6cabab0 100644 --- a/aiomysql/sa/connection.py +++ b/aiomysql/sa/connection.py @@ -13,6 +13,10 @@ from ..utils import _TransactionContextManager, _SAConnectionContextManager +def noop(k): + return k + + class SAConnection: def __init__(self, connection, engine, compiled_cache=None): @@ -64,16 +68,79 @@ def execute(self, query, *multiparams, **params): coro = self._execute(query, *multiparams, **params) return _SAConnectionContextManager(coro) + def _base_params(self, query, dp, compiled, is_update): + """ + handle params + """ + if dp and isinstance(dp, (list, tuple)): + if is_update: + dp = {c.key: pval for c, pval in zip(query.table.c, dp)} + else: + raise exc.ArgumentError( + "Don't mix sqlalchemy SELECT " + "clause with positional " + "parameters" + ) + compiled_params = compiled.construct_params(dp) + processors = compiled._bind_processors + params = [{ + key: processors.get(key, noop)(compiled_params[key]) + for key in compiled_params + }] + post_processed_params = self._dialect.execute_sequence_format(params) + return post_processed_params[0] + + async def _executemany(self, query, dps, cursor): + """ + executemany + """ + result_map = None + if isinstance(query, str): + await cursor.executemany(query, dps) + elif isinstance(query, DDLElement): + raise exc.ArgumentError( + "Don't mix sqlalchemy DDL clause " + "and execution with parameters" + ) + elif isinstance(query, ClauseElement): + compiled = query.compile(dialect=self._dialect) + params = [] + is_update = isinstance(query, UpdateBase) + for dp in dps: + params.append( + self._base_params( + query, + dp, + compiled, + is_update, + ) + ) + await cursor.executemany(str(compiled), params) + result_map = compiled._result_columns + else: + raise exc.ArgumentError( + "sql statement should be str or " + "SQLAlchemy data " + "selection/modification clause" + ) + ret = await create_result_proxy( + self, + cursor, + self._dialect, + result_map + ) + self._weak_results.add(ret) + return ret + async def _execute(self, query, *multiparams, **params): cursor = await self._connection.cursor() dp = _distill_params(multiparams, params) if len(dp) > 1: - raise exc.ArgumentError("aiomysql doesn't support executemany") + return await self._executemany(query, dp, cursor) elif dp: dp = dp[0] result_map = None - if isinstance(query, str): await cursor.execute(query, dp or None) elif isinstance(query, ClauseElement): @@ -90,35 +157,20 @@ async def _execute(self, query, *multiparams, **params): compiled = query.compile(dialect=self._dialect) if not isinstance(query, DDLElement): - if dp and isinstance(dp, (list, tuple)): - if isinstance(query, UpdateBase): - dp = {c.key: pval - for c, pval in zip(query.table.c, dp)} - else: - raise exc.ArgumentError("Don't mix sqlalchemy SELECT " - "clause with positional " - "parameters") - compiled_parameters = [compiled.construct_params( - dp)] - processed_parameters = [] - processors = compiled._bind_processors - for compiled_params in compiled_parameters: - params = {key: (processors[key](compiled_params[key]) - if key in processors - else compiled_params[key]) - for key in compiled_params} - processed_parameters.append(params) - post_processed_params = self._dialect.execute_sequence_format( - processed_parameters) + post_processed_params = self._base_params( + query, + dp, + compiled, + isinstance(query, UpdateBase) + ) result_map = compiled._result_columns - else: if dp: raise exc.ArgumentError("Don't mix sqlalchemy DDL clause " "and execution with parameters") - post_processed_params = [compiled.construct_params()] + post_processed_params = compiled.construct_params() result_map = None - await cursor.execute(str(compiled), post_processed_params[0]) + await cursor.execute(str(compiled), post_processed_params) else: raise exc.ArgumentError("sql statement should be str or " "SQLAlchemy data " diff --git a/tests/sa/test_sa_connection.py b/tests/sa/test_sa_connection.py index dce50c59..37b823d6 100644 --- a/tests/sa/test_sa_connection.py +++ b/tests/sa/test_sa_connection.py @@ -8,6 +8,7 @@ from sqlalchemy import MetaData, Table, Column, Integer, String from sqlalchemy.schema import DropTable, CreateTable +from sqlalchemy.sql.expression import bindparam meta = MetaData() @@ -269,10 +270,31 @@ async def go(): def test_raw_insert_with_executemany(self): async def go(): conn = await self.connect() + # with self.assertRaises(sa.ArgumentError): + await conn.execute( + "INSERT INTO sa_tbl (id, name) VALUES (%(id)s, %(name)s)", + [{"id": 2, "name": 'third'}, {"id": 3, "name": 'forth'}]) + await conn.execute( + tbl.update().where( + tbl.c.id == bindparam("id") + ).values( + {"name": bindparam("name")} + ), + [ + {"id": 2, "name": "t2"}, + {"id": 3, "name": "t3"} + ] + ) + with self.assertRaises(sa.ArgumentError): + await conn.execute( + DropTable(tbl), + [{}, {}] + ) with self.assertRaises(sa.ArgumentError): await conn.execute( - "INSERT INTO sa_tbl (id, name) VALUES (%(id)s, %(name)s)", - [(2, 'third'), (3, 'forth')]) + {}, + [{}, {}] + ) self.loop.run_until_complete(go()) def test_raw_select_with_wildcard(self):