Skip to content

Commit 9b5b002

Browse files
author
Jim Fulton
authored
fix: unnest failed in some cases (with table references failed when there were no other references to referenced tables in a query) (#290)
1 parent 5e9f4c2 commit 9b5b002

File tree

7 files changed

+208
-50
lines changed

7 files changed

+208
-50
lines changed

setup.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,11 @@ def readme():
4545
return f.read()
4646

4747

48-
extras = dict(geography=["GeoAlchemy2", "shapely"], alembic=["alembic"], tests=["pytz"])
48+
extras = dict(
49+
geography=["GeoAlchemy2", "shapely"],
50+
alembic=["alembic"],
51+
tests=["packaging", "pytz"],
52+
)
4953
extras["all"] = set(itertools.chain.from_iterable(extras.values()))
5054

5155
setup(
@@ -85,7 +89,7 @@ def readme():
8589
],
8690
extras_require=extras,
8791
python_requires=">=3.6, <3.10",
88-
tests_require=["pytz"],
92+
tests_require=["packaging", "pytz"],
8993
entry_points={
9094
"sqlalchemy.dialects": ["bigquery = sqlalchemy_bigquery:BigQueryDialect"]
9195
},

sqlalchemy_bigquery/__init__.py

+22-20
Original file line numberDiff line numberDiff line change
@@ -24,40 +24,42 @@
2424

2525
from .base import BigQueryDialect, dialect # noqa
2626
from .base import (
27-
STRING,
27+
ARRAY,
28+
BIGNUMERIC,
2829
BOOL,
2930
BOOLEAN,
31+
BYTES,
32+
DATE,
33+
DATETIME,
34+
FLOAT,
35+
FLOAT64,
3036
INT64,
3137
INTEGER,
32-
FLOAT64,
33-
FLOAT,
34-
TIMESTAMP,
35-
DATETIME,
36-
DATE,
37-
BYTES,
38-
TIME,
39-
RECORD,
4038
NUMERIC,
41-
BIGNUMERIC,
39+
RECORD,
40+
STRING,
41+
TIME,
42+
TIMESTAMP,
4243
)
4344

4445
__all__ = [
46+
"ARRAY",
47+
"BIGNUMERIC",
4548
"BigQueryDialect",
46-
"STRING",
4749
"BOOL",
4850
"BOOLEAN",
51+
"BYTES",
52+
"DATE",
53+
"DATETIME",
54+
"FLOAT",
55+
"FLOAT64",
4956
"INT64",
5057
"INTEGER",
51-
"FLOAT64",
52-
"FLOAT",
53-
"TIMESTAMP",
54-
"DATETIME",
55-
"DATE",
56-
"BYTES",
57-
"TIME",
58-
"RECORD",
5958
"NUMERIC",
60-
"BIGNUMERIC",
59+
"RECORD",
60+
"STRING",
61+
"TIME",
62+
"TIMESTAMP",
6163
]
6264

6365
try:

sqlalchemy_bigquery/base.py

+81-20
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262

6363
FIELD_ILLEGAL_CHARACTERS = re.compile(r"[^\w]+")
6464

65+
TABLE_VALUED_ALIAS_ALIASES = "bigquery_table_valued_alias_aliases"
66+
6567

6668
def assert_(cond, message="Assertion failed"): # pragma: NO COVER
6769
if not cond:
@@ -114,39 +116,41 @@ def format_label(self, label, name=None):
114116

115117

116118
_type_map = {
117-
"STRING": types.String,
118-
"BOOL": types.Boolean,
119+
"ARRAY": types.ARRAY,
120+
"BIGNUMERIC": types.Numeric,
119121
"BOOLEAN": types.Boolean,
120-
"INT64": types.Integer,
121-
"INTEGER": types.Integer,
122+
"BOOL": types.Boolean,
123+
"BYTES": types.BINARY,
124+
"DATETIME": types.DATETIME,
125+
"DATE": types.DATE,
122126
"FLOAT64": types.Float,
123127
"FLOAT": types.Float,
128+
"INT64": types.Integer,
129+
"INTEGER": types.Integer,
130+
"NUMERIC": types.Numeric,
131+
"RECORD": types.JSON,
132+
"STRING": types.String,
124133
"TIMESTAMP": types.TIMESTAMP,
125-
"DATETIME": types.DATETIME,
126-
"DATE": types.DATE,
127-
"BYTES": types.BINARY,
128134
"TIME": types.TIME,
129-
"RECORD": types.JSON,
130-
"NUMERIC": types.Numeric,
131-
"BIGNUMERIC": types.Numeric,
132135
}
133136

134137
# By convention, dialect-provided types are spelled with all upper case.
135-
STRING = _type_map["STRING"]
136-
BOOL = _type_map["BOOL"]
138+
ARRAY = _type_map["ARRAY"]
139+
BIGNUMERIC = _type_map["NUMERIC"]
137140
BOOLEAN = _type_map["BOOLEAN"]
138-
INT64 = _type_map["INT64"]
139-
INTEGER = _type_map["INTEGER"]
141+
BOOL = _type_map["BOOL"]
142+
BYTES = _type_map["BYTES"]
143+
DATETIME = _type_map["DATETIME"]
144+
DATE = _type_map["DATE"]
140145
FLOAT64 = _type_map["FLOAT64"]
141146
FLOAT = _type_map["FLOAT"]
147+
INT64 = _type_map["INT64"]
148+
INTEGER = _type_map["INTEGER"]
149+
NUMERIC = _type_map["NUMERIC"]
150+
RECORD = _type_map["RECORD"]
151+
STRING = _type_map["STRING"]
142152
TIMESTAMP = _type_map["TIMESTAMP"]
143-
DATETIME = _type_map["DATETIME"]
144-
DATE = _type_map["DATE"]
145-
BYTES = _type_map["BYTES"]
146153
TIME = _type_map["TIME"]
147-
RECORD = _type_map["RECORD"]
148-
NUMERIC = _type_map["NUMERIC"]
149-
BIGNUMERIC = _type_map["NUMERIC"]
150154

151155
try:
152156
_type_map["GEOGRAPHY"] = GEOGRAPHY
@@ -246,6 +250,56 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
246250
insert_stmt, asfrom=False, **kw
247251
)
248252

253+
def visit_table_valued_alias(self, element, **kw):
254+
# When using table-valued functions, like UNNEST, BigQuery requires a
255+
# FROM for any table referenced in the function, including expressions
256+
# in function arguments.
257+
#
258+
# For example, given SQLAlchemy code:
259+
#
260+
# print(
261+
# select([func.unnest(foo.c.objects).alias('foo_objects').column])
262+
# .compile(engine))
263+
#
264+
# Left to its own devices, SQLAlchemy would output:
265+
#
266+
# SELECT `foo_objects`
267+
# FROM unnest(`foo`.`objects`) AS `foo_objects`
268+
#
269+
# But BigQuery doesn't understand the `foo` reference unless
270+
# we add a reference to `foo` in the FROM:
271+
#
272+
# SELECT foo_objects
273+
# FROM `foo`, UNNEST(`foo`.`objects`) as foo_objects
274+
#
275+
# This is tricky because:
276+
# 1. We have to find the table references.
277+
# 2. We can't know practically if there's already a FROM for a table.
278+
#
279+
# We leverage visit_column to find a table reference. Whenever we find
280+
# one, we create an alias for it, so as not to conflict with an existing
281+
# reference if one is present.
282+
#
283+
# This requires communicating between this function and visit_column.
284+
# We do this by sticking a dictionary in the keyword arguments.
285+
# This dictionary:
286+
# a. Tells visit_column that it's in a table-valued alias expression, and
287+
# b. Gives it a place to record the aliases it creates.
288+
#
289+
# This function creates aliases in the FROM list for any aliases recorded
290+
# by visit_column.
291+
292+
kw[TABLE_VALUED_ALIAS_ALIASES] = {}
293+
ret = super().visit_table_valued_alias(element, **kw)
294+
aliases = kw.pop(TABLE_VALUED_ALIAS_ALIASES)
295+
if aliases:
296+
aliases = ", ".join(
297+
f"{self.preparer.quote(tablename)} {self.preparer.quote(alias)}"
298+
for tablename, alias in aliases.items()
299+
)
300+
ret = f"{aliases}, {ret}"
301+
return ret
302+
249303
def visit_column(
250304
self,
251305
column,
@@ -281,6 +335,13 @@ def visit_column(
281335
tablename = table.name
282336
if isinstance(tablename, elements._truncated_label):
283337
tablename = self._truncated_identifier("alias", tablename)
338+
elif TABLE_VALUED_ALIAS_ALIASES in kwargs:
339+
aliases = kwargs[TABLE_VALUED_ALIAS_ALIASES]
340+
if tablename not in aliases:
341+
aliases[tablename] = self.anon_map[
342+
f"{TABLE_VALUED_ALIAS_ALIASES} {tablename}"
343+
]
344+
tablename = aliases[tablename]
284345

285346
return self.preparer.quote(tablename) + "." + name
286347

tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import datetime
2121
import mock
22+
import packaging.version
2223
import pytest
2324
import pytz
2425
import sqlalchemy
@@ -41,7 +42,7 @@
4142
)
4243

4344

44-
if sqlalchemy.__version__ < "1.4":
45+
if packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"):
4546
from sqlalchemy.testing.suite import LimitOffsetTest as _LimitOffsetTest
4647

4748
class LimitOffsetTest(_LimitOffsetTest):

tests/system/test_sqlalchemy_bigquery.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@
2828
from sqlalchemy.sql import expression, select, literal_column
2929
from sqlalchemy.exc import NoSuchTableError
3030
from sqlalchemy.orm import sessionmaker
31+
import packaging.version
3132
from pytz import timezone
3233
import pytest
3334
import sqlalchemy
3435
import datetime
3536
import decimal
3637

37-
3838
ONE_ROW_CONTENTS_EXPANDED = [
3939
588,
4040
datetime.datetime(2013, 10, 10, 11, 27, 16, tzinfo=timezone("UTC")),
@@ -725,3 +725,31 @@ class MyTable(Base):
725725
)
726726

727727
assert sorted(db.query(sqlalchemy.distinct(MyTable.my_column)).all()) == expected
728+
729+
730+
@pytest.mark.skipif(
731+
packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"),
732+
reason="unnest (and other table-valued-function) support required version 1.4",
733+
)
734+
def test_unnest(engine, bigquery_dataset):
735+
from sqlalchemy import select, func, String
736+
from sqlalchemy_bigquery import ARRAY
737+
738+
conn = engine.connect()
739+
metadata = MetaData()
740+
table = Table(
741+
f"{bigquery_dataset}.test_unnest", metadata, Column("objects", ARRAY(String)),
742+
)
743+
metadata.create_all(engine)
744+
conn.execute(
745+
table.insert(), [dict(objects=["a", "b", "c"]), dict(objects=["x", "y"])]
746+
)
747+
query = select([func.unnest(table.c.objects).alias("foo_objects").column])
748+
compiled = str(query.compile(engine))
749+
assert " ".join(compiled.strip().split()) == (
750+
f"SELECT `foo_objects`"
751+
f" FROM"
752+
f" `{bigquery_dataset}.test_unnest` `{bigquery_dataset}.test_unnest_1`,"
753+
f" unnest(`{bigquery_dataset}.test_unnest_1`.`objects`) AS `foo_objects`"
754+
)
755+
assert sorted(r[0] for r in conn.execute(query)) == ["a", "b", "c", "x", "y"]

tests/unit/conftest.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,24 @@
2121
import mock
2222
import sqlite3
2323

24+
import packaging.version
2425
import pytest
2526
import sqlalchemy
2627

2728
import fauxdbi
2829

29-
sqlalchemy_version_info = tuple(map(int, sqlalchemy.__version__.split(".")))
30+
sqlalchemy_version = packaging.version.parse(sqlalchemy.__version__)
3031
sqlalchemy_1_3_or_higher = pytest.mark.skipif(
31-
sqlalchemy_version_info < (1, 3), reason="requires sqlalchemy 1.3 or higher"
32+
sqlalchemy_version < packaging.version.parse("1.3"),
33+
reason="requires sqlalchemy 1.3 or higher",
3234
)
3335
sqlalchemy_1_4_or_higher = pytest.mark.skipif(
34-
sqlalchemy_version_info < (1, 4), reason="requires sqlalchemy 1.4 or higher"
36+
sqlalchemy_version < packaging.version.parse("1.4"),
37+
reason="requires sqlalchemy 1.4 or higher",
3538
)
3639
sqlalchemy_before_1_4 = pytest.mark.skipif(
37-
sqlalchemy_version_info >= (1, 4), reason="requires sqlalchemy 1.3 or lower"
40+
sqlalchemy_version >= packaging.version.parse("1.4"),
41+
reason="requires sqlalchemy 1.3 or lower",
3842
)
3943

4044

0 commit comments

Comments
 (0)