Skip to content

Commit

Permalink
fix: support executemany (#324)
Browse files Browse the repository at this point in the history
* fix: support executemany

* fix: flake8

* fix: flake8

* fix: test support executemany(

* fix: local variable referenced before assignment

* test: add more executemany test

* fix: flake8

* fix: review

* fix: lint
  • Loading branch information
zeromake authored and jettify committed Dec 1, 2018
1 parent 2446872 commit 9e48e85
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 27 deletions.
102 changes: 77 additions & 25 deletions aiomysql/sa/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
from ..utils import _TransactionContextManager, _SAConnectionContextManager


def noop(k):
return k


class SAConnection:

def __init__(self, connection, engine, compiled_cache=None):
Expand Down Expand Up @@ -64,16 +68,79 @@ def execute(self, query, *multiparams, **params):
coro = self._execute(query, *multiparams, **params)
return _SAConnectionContextManager(coro)

def _base_params(self, query, dp, compiled, is_update):
"""
handle params
"""
if dp and isinstance(dp, (list, tuple)):
if is_update:
dp = {c.key: pval for c, pval in zip(query.table.c, dp)}
else:
raise exc.ArgumentError(
"Don't mix sqlalchemy SELECT "
"clause with positional "
"parameters"
)
compiled_params = compiled.construct_params(dp)
processors = compiled._bind_processors
params = [{
key: processors.get(key, noop)(compiled_params[key])
for key in compiled_params
}]
post_processed_params = self._dialect.execute_sequence_format(params)
return post_processed_params[0]

async def _executemany(self, query, dps, cursor):
"""
executemany
"""
result_map = None
if isinstance(query, str):
await cursor.executemany(query, dps)
elif isinstance(query, DDLElement):
raise exc.ArgumentError(
"Don't mix sqlalchemy DDL clause "
"and execution with parameters"
)
elif isinstance(query, ClauseElement):
compiled = query.compile(dialect=self._dialect)
params = []
is_update = isinstance(query, UpdateBase)
for dp in dps:
params.append(
self._base_params(
query,
dp,
compiled,
is_update,
)
)
await cursor.executemany(str(compiled), params)
result_map = compiled._result_columns
else:
raise exc.ArgumentError(
"sql statement should be str or "
"SQLAlchemy data "
"selection/modification clause"
)
ret = await create_result_proxy(
self,
cursor,
self._dialect,
result_map
)
self._weak_results.add(ret)
return ret

async def _execute(self, query, *multiparams, **params):
cursor = await self._connection.cursor()
dp = _distill_params(multiparams, params)
if len(dp) > 1:
raise exc.ArgumentError("aiomysql doesn't support executemany")
return await self._executemany(query, dp, cursor)
elif dp:
dp = dp[0]

result_map = None

if isinstance(query, str):
await cursor.execute(query, dp or None)
elif isinstance(query, ClauseElement):
Expand All @@ -90,35 +157,20 @@ async def _execute(self, query, *multiparams, **params):
compiled = query.compile(dialect=self._dialect)

if not isinstance(query, DDLElement):
if dp and isinstance(dp, (list, tuple)):
if isinstance(query, UpdateBase):
dp = {c.key: pval
for c, pval in zip(query.table.c, dp)}
else:
raise exc.ArgumentError("Don't mix sqlalchemy SELECT "
"clause with positional "
"parameters")
compiled_parameters = [compiled.construct_params(
dp)]
processed_parameters = []
processors = compiled._bind_processors
for compiled_params in compiled_parameters:
params = {key: (processors[key](compiled_params[key])
if key in processors
else compiled_params[key])
for key in compiled_params}
processed_parameters.append(params)
post_processed_params = self._dialect.execute_sequence_format(
processed_parameters)
post_processed_params = self._base_params(
query,
dp,
compiled,
isinstance(query, UpdateBase)
)
result_map = compiled._result_columns

else:
if dp:
raise exc.ArgumentError("Don't mix sqlalchemy DDL clause "
"and execution with parameters")
post_processed_params = [compiled.construct_params()]
post_processed_params = compiled.construct_params()
result_map = None
await cursor.execute(str(compiled), post_processed_params[0])
await cursor.execute(str(compiled), post_processed_params)
else:
raise exc.ArgumentError("sql statement should be str or "
"SQLAlchemy data "
Expand Down
26 changes: 24 additions & 2 deletions tests/sa/test_sa_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from sqlalchemy import MetaData, Table, Column, Integer, String
from sqlalchemy.schema import DropTable, CreateTable
from sqlalchemy.sql.expression import bindparam


meta = MetaData()
Expand Down Expand Up @@ -269,10 +270,31 @@ async def go():
def test_raw_insert_with_executemany(self):
async def go():
conn = await self.connect()
# with self.assertRaises(sa.ArgumentError):
await conn.execute(
"INSERT INTO sa_tbl (id, name) VALUES (%(id)s, %(name)s)",
[{"id": 2, "name": 'third'}, {"id": 3, "name": 'forth'}])
await conn.execute(
tbl.update().where(
tbl.c.id == bindparam("id")
).values(
{"name": bindparam("name")}
),
[
{"id": 2, "name": "t2"},
{"id": 3, "name": "t3"}
]
)
with self.assertRaises(sa.ArgumentError):
await conn.execute(
DropTable(tbl),
[{}, {}]
)
with self.assertRaises(sa.ArgumentError):
await conn.execute(
"INSERT INTO sa_tbl (id, name) VALUES (%(id)s, %(name)s)",
[(2, 'third'), (3, 'forth')])
{},
[{}, {}]
)
self.loop.run_until_complete(go())

def test_raw_select_with_wildcard(self):
Expand Down

0 comments on commit 9e48e85

Please sign in to comment.