Skip to content

Commit 7692704

Browse files
author
Jim Fulton
authored
feat: Handle passing of arrays to in statements more efficiently in SQLAlchemy 1.4 and higher (#253)
1 parent 9b5b002 commit 7692704

File tree

4 files changed

+141
-79
lines changed

4 files changed

+141
-79
lines changed

sqlalchemy_bigquery/base.py

+36-2
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,39 @@ def visit_bindparam(
483483
skip_bind_expression=False,
484484
**kwargs,
485485
):
486+
type_ = bindparam.type
487+
unnest = False
488+
if (
489+
bindparam.expanding
490+
and not isinstance(type_, NullType)
491+
and not literal_binds
492+
):
493+
# Normally, when performing an IN operation, like:
494+
#
495+
# foo IN (some_sequence)
496+
#
497+
# SQAlchemy passes `foo` as a parameter and unpacks
498+
# `some_sequence` and passes each element as a parameter.
499+
# This mechanism is refered to as "expanding". It's
500+
# inefficient and can't handle large arrays. (It's also
501+
# very complicated, but that's not the issue we care about
502+
# here. :) ) BigQuery lets us use arrays directly in this
503+
# context, we just need to call UNNEST on an array when
504+
# it's used in IN.
505+
#
506+
# So, if we get an `expanding` flag, and if we have a known type
507+
# (and don't have literal binds, which are implemented in-line in
508+
# in the SQL), we turn off expanding and we set an unnest flag
509+
# so that we add an UNNEST() call (below).
510+
#
511+
# The NullType/known-type check has to do with some extreme
512+
# edge cases having to do with empty in-lists that get special
513+
# hijinks from SQLAlchemy that we don't want to disturb. :)
514+
if getattr(bindparam, "expand_op", None) is not None:
515+
assert bindparam.expand_op.__name__.endswith("in_op") # in in
516+
bindparam.expanding = False
517+
unnest = True
518+
486519
param = super(BigQueryCompiler, self).visit_bindparam(
487520
bindparam,
488521
within_columns_clause,
@@ -491,7 +524,6 @@ def visit_bindparam(
491524
**kwargs,
492525
)
493526

494-
type_ = bindparam.type
495527
if literal_binds or isinstance(type_, NullType):
496528
return param
497529

@@ -512,7 +544,6 @@ def visit_bindparam(
512544
if bq_type[-1] == ">" and bq_type.startswith("ARRAY<"):
513545
# Values get arrayified at a lower level.
514546
bq_type = bq_type[6:-1]
515-
516547
bq_type = self.__remove_type_parameter(bq_type)
517548

518549
assert_(param != "%s", f"Unexpected param: {param}")
@@ -528,6 +559,9 @@ def visit_bindparam(
528559
assert_(type_ is None)
529560
param = f"%({name}:{bq_type})s"
530561

562+
if unnest:
563+
param = f"UNNEST({param})"
564+
531565
return param
532566

533567

tests/system/test_sqlalchemy_bigquery.py

+21
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,27 @@ class MyTable(Base):
727727
assert sorted(db.query(sqlalchemy.distinct(MyTable.my_column)).all()) == expected
728728

729729

730+
@pytest.mark.skipif(
731+
packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"),
732+
reason="requires sqlalchemy 1.4 or higher",
733+
)
734+
def test_huge_in():
735+
engine = sqlalchemy.create_engine("bigquery://")
736+
conn = engine.connect()
737+
try:
738+
assert list(
739+
conn.execute(
740+
sqlalchemy.select([sqlalchemy.literal(-1).in_(list(range(99999)))])
741+
)
742+
) == [(False,)]
743+
except Exception:
744+
error = True
745+
else:
746+
error = False
747+
748+
assert not error, "execution failed"
749+
750+
730751
@pytest.mark.skipif(
731752
packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"),
732753
reason="unnest (and other table-valued-function) support required version 1.4",

tests/unit/fauxdbi.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,20 @@ def __handle_problematic_literal_inserts(
261261
else:
262262
return operation
263263

264-
__handle_unnest = substitute_string_re_method(
265-
r"UNNEST\(\[ ([^\]]+)? \]\)", # UNNEST([ ... ])
266-
flags=re.IGNORECASE,
267-
repl=r"(\1)",
264+
@substitute_re_method(
265+
r"""
266+
UNNEST\(
267+
(
268+
\[ (?P<exp>[^\]]+)? \] # UNNEST([ ... ])
269+
|
270+
([?]) # UNNEST(?)
271+
)
272+
\)
273+
""",
274+
flags=re.IGNORECASE | re.VERBOSE,
268275
)
276+
def __handle_unnest(self, m):
277+
return "(" + (m.group("exp") or "?") + ")"
269278

270279
def __handle_true_false(self, operation):
271280
# Older sqlite versions, like those used on the CI servers

tests/unit/test_select.py

+71-73
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from conftest import (
3030
setup_table,
31+
sqlalchemy_version,
3132
sqlalchemy_1_3_or_higher,
3233
sqlalchemy_1_4_or_higher,
3334
sqlalchemy_before_1_4,
@@ -214,18 +215,6 @@ def test_disable_quote(faux_conn):
214215
assert faux_conn.test_data["execute"][-1][0] == ("SELECT `t`.foo \nFROM `t`")
215216

216217

217-
def _normalize_in_params(query, params):
218-
# We have to normalize parameter names, because they
219-
# change with sqlalchemy versions.
220-
newnames = sorted(
221-
((p, f"p_{i}") for i, p in enumerate(sorted(params))), key=lambda i: -len(i[0])
222-
)
223-
for old, new in newnames:
224-
query = query.replace(old, new)
225-
226-
return query, {new: params[old] for old, new in newnames}
227-
228-
229218
@sqlalchemy_before_1_4
230219
def test_select_in_lit_13(faux_conn):
231220
[[isin]] = faux_conn.execute(
@@ -240,66 +229,74 @@ def test_select_in_lit_13(faux_conn):
240229

241230

242231
@sqlalchemy_1_4_or_higher
243-
def test_select_in_lit(faux_conn):
244-
[[isin]] = faux_conn.execute(
245-
sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])])
246-
)
247-
assert isin
248-
assert _normalize_in_params(*faux_conn.test_data["execute"][-1]) == (
249-
"SELECT %(p_0:INT64)s IN "
250-
"UNNEST([ %(p_1:INT64)s, %(p_2:INT64)s, %(p_3:INT64)s ]) AS `anon_1`",
251-
{"p_1": 1, "p_2": 2, "p_3": 3, "p_0": 1},
232+
def test_select_in_lit(faux_conn, last_query):
233+
faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])]))
234+
last_query(
235+
"SELECT %(param_1:INT64)s IN UNNEST(%(param_2:INT64)s) AS `anon_1`",
236+
{"param_1": 1, "param_2": [1, 2, 3]},
252237
)
253238

254239

255-
def test_select_in_param(faux_conn):
240+
def test_select_in_param(faux_conn, last_query):
256241
[[isin]] = faux_conn.execute(
257242
sqlalchemy.select(
258243
[sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))]
259244
),
260245
dict(q=[1, 2, 3]),
261246
)
262-
assert isin
263-
assert faux_conn.test_data["execute"][-1] == (
264-
"SELECT %(param_1:INT64)s IN UNNEST("
265-
"[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]"
266-
") AS `anon_1`",
267-
{"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3},
268-
)
247+
if sqlalchemy_version >= packaging.version.parse("1.4"):
248+
last_query(
249+
"SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`",
250+
{"param_1": 1, "q": [1, 2, 3]},
251+
)
252+
else:
253+
assert isin
254+
last_query(
255+
"SELECT %(param_1:INT64)s IN UNNEST("
256+
"[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]"
257+
") AS `anon_1`",
258+
{"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3},
259+
)
269260

270261

271-
def test_select_in_param1(faux_conn):
262+
def test_select_in_param1(faux_conn, last_query):
272263
[[isin]] = faux_conn.execute(
273264
sqlalchemy.select(
274265
[sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))]
275266
),
276267
dict(q=[1]),
277268
)
278-
assert isin
279-
assert faux_conn.test_data["execute"][-1] == (
280-
"SELECT %(param_1:INT64)s IN UNNEST(" "[ %(q_1:INT64)s ]" ") AS `anon_1`",
281-
{"param_1": 1, "q_1": 1},
282-
)
269+
if sqlalchemy_version >= packaging.version.parse("1.4"):
270+
last_query(
271+
"SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`",
272+
{"param_1": 1, "q": [1]},
273+
)
274+
else:
275+
assert isin
276+
last_query(
277+
"SELECT %(param_1:INT64)s IN UNNEST(" "[ %(q_1:INT64)s ]" ") AS `anon_1`",
278+
{"param_1": 1, "q_1": 1},
279+
)
283280

284281

285282
@sqlalchemy_1_3_or_higher
286-
def test_select_in_param_empty(faux_conn):
283+
def test_select_in_param_empty(faux_conn, last_query):
287284
[[isin]] = faux_conn.execute(
288285
sqlalchemy.select(
289286
[sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))]
290287
),
291288
dict(q=[]),
292289
)
293-
assert not isin
294-
assert faux_conn.test_data["execute"][-1] == (
295-
"SELECT %(param_1:INT64)s IN(NULL) AND (1 != 1) AS `anon_1`"
296-
if (
297-
packaging.version.parse(sqlalchemy.__version__)
298-
>= packaging.version.parse("1.4")
290+
if sqlalchemy_version >= packaging.version.parse("1.4"):
291+
last_query(
292+
"SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`",
293+
{"param_1": 1, "q": []},
294+
)
295+
else:
296+
assert not isin
297+
last_query(
298+
"SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`", {"param_1": 1}
299299
)
300-
else "SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`",
301-
{"param_1": 1},
302-
)
303300

304301

305302
@sqlalchemy_before_1_4
@@ -316,53 +313,54 @@ def test_select_notin_lit13(faux_conn):
316313

317314

318315
@sqlalchemy_1_4_or_higher
319-
def test_select_notin_lit(faux_conn):
320-
[[isnotin]] = faux_conn.execute(
321-
sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])])
316+
def test_select_notin_lit(faux_conn, last_query):
317+
faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])]))
318+
last_query(
319+
"SELECT (%(param_1:INT64)s NOT IN UNNEST(%(param_2:INT64)s)) AS `anon_1`",
320+
{"param_1": 0, "param_2": [1, 2, 3]},
322321
)
323-
assert isnotin
324322

325-
assert _normalize_in_params(*faux_conn.test_data["execute"][-1]) == (
326-
"SELECT (%(p_0:INT64)s NOT IN "
327-
"UNNEST([ %(p_1:INT64)s, %(p_2:INT64)s, %(p_3:INT64)s ])) AS `anon_1`",
328-
{"p_0": 0, "p_1": 1, "p_2": 2, "p_3": 3},
329-
)
330323

331-
332-
def test_select_notin_param(faux_conn):
324+
def test_select_notin_param(faux_conn, last_query):
333325
[[isnotin]] = faux_conn.execute(
334326
sqlalchemy.select(
335327
[sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))]
336328
),
337329
dict(q=[1, 2, 3]),
338330
)
339-
assert not isnotin
340-
assert faux_conn.test_data["execute"][-1] == (
341-
"SELECT (%(param_1:INT64)s NOT IN UNNEST("
342-
"[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]"
343-
")) AS `anon_1`",
344-
{"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3},
345-
)
331+
if sqlalchemy_version >= packaging.version.parse("1.4"):
332+
last_query(
333+
"SELECT (%(param_1:INT64)s NOT IN UNNEST(%(q:INT64)s)) AS `anon_1`",
334+
{"param_1": 1, "q": [1, 2, 3]},
335+
)
336+
else:
337+
assert not isnotin
338+
last_query(
339+
"SELECT (%(param_1:INT64)s NOT IN UNNEST("
340+
"[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]"
341+
")) AS `anon_1`",
342+
{"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3},
343+
)
346344

347345

348346
@sqlalchemy_1_3_or_higher
349-
def test_select_notin_param_empty(faux_conn):
347+
def test_select_notin_param_empty(faux_conn, last_query):
350348
[[isnotin]] = faux_conn.execute(
351349
sqlalchemy.select(
352350
[sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))]
353351
),
354352
dict(q=[]),
355353
)
356-
assert isnotin
357-
assert faux_conn.test_data["execute"][-1] == (
358-
"SELECT (%(param_1:INT64)s NOT IN(NULL) OR (1 = 1)) AS `anon_1`"
359-
if (
360-
packaging.version.parse(sqlalchemy.__version__)
361-
>= packaging.version.parse("1.4")
354+
if sqlalchemy_version >= packaging.version.parse("1.4"):
355+
last_query(
356+
"SELECT (%(param_1:INT64)s NOT IN UNNEST(%(q:INT64)s)) AS `anon_1`",
357+
{"param_1": 1, "q": []},
358+
)
359+
else:
360+
assert isnotin
361+
last_query(
362+
"SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`", {"param_1": 1}
362363
)
363-
else "SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`",
364-
{"param_1": 1},
365-
)
366364

367365

368366
def test_literal_binds_kwarg_with_an_IN_operator_252(faux_conn):

0 commit comments

Comments
 (0)