Skip to content

Commit

Permalink
fix: the unnest function lost needed type information (#298)
Browse files Browse the repository at this point in the history
  • Loading branch information
jimfulton authored Aug 26, 2021
1 parent 6ffcef6 commit 1233182
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
17 changes: 17 additions & 0 deletions sqlalchemy_bigquery/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from google.api_core.exceptions import NotFound

import sqlalchemy
import sqlalchemy.sql.expression
import sqlalchemy.sql.functions
import sqlalchemy.sql.sqltypes
import sqlalchemy.sql.type_api
from sqlalchemy.exc import NoSuchTableError
Expand Down Expand Up @@ -1092,6 +1094,21 @@ def get_view_definition(self, connection, view_name, schema=None, **kw):
return view.view_query


class unnest(sqlalchemy.sql.functions.GenericFunction):
def __init__(self, *args, **kwargs):
expr = kwargs.pop("expr", None)
if expr is not None:
args = (expr,) + args
if len(args) != 1:
raise TypeError("The unnest function requires a single argument.")
arg = args[0]
if isinstance(arg, sqlalchemy.sql.expression.ColumnElement):
if not isinstance(arg.type, sqlalchemy.sql.sqltypes.ARRAY):
raise TypeError("The argument to unnest must have an ARRAY type.")
self.type = arg.type.item_type
super().__init__(*args, **kwargs)


dialect = BigQueryDialect

try:
Expand Down
51 changes: 51 additions & 0 deletions tests/unit/test_sqlalchemy_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from google.cloud import bigquery
from google.cloud.bigquery.dataset import DatasetListItem
from google.cloud.bigquery.table import TableListItem
import packaging.version
import pytest
import sqlalchemy

Expand Down Expand Up @@ -178,3 +179,53 @@ def test_follow_dialect_attribute_convention():

assert sqlalchemy_bigquery.dialect is sqlalchemy_bigquery.BigQueryDialect
assert sqlalchemy_bigquery.base.dialect is sqlalchemy_bigquery.BigQueryDialect


@pytest.mark.parametrize(
"args,kw,error",
[
((), {}, "The unnest function requires a single argument."),
((1, 1), {}, "The unnest function requires a single argument."),
((1,), {"expr": 1}, "The unnest function requires a single argument."),
((1, 1), {"expr": 1}, "The unnest function requires a single argument."),
(
(),
{"expr": sqlalchemy.Column("x", sqlalchemy.String)},
"The argument to unnest must have an ARRAY type.",
),
(
(sqlalchemy.Column("x", sqlalchemy.String),),
{},
"The argument to unnest must have an ARRAY type.",
),
],
)
def test_unnest_function_errors(args, kw, error):
# Make sure the unnest function is registered with SQLAlchemy, which
# happens when sqlalchemy_bigquery is imported.
import sqlalchemy_bigquery # noqa

with pytest.raises(TypeError, match=error):
sqlalchemy.func.unnest(*args, **kw)


@pytest.mark.parametrize(
"args,kw",
[
((), {"expr": sqlalchemy.Column("x", sqlalchemy.ARRAY(sqlalchemy.String))}),
((sqlalchemy.Column("x", sqlalchemy.ARRAY(sqlalchemy.String)),), {}),
],
)
def test_unnest_function(args, kw):
# Make sure the unnest function is registered with SQLAlchemy, which
# happens when sqlalchemy_bigquery is imported.
import sqlalchemy_bigquery # noqa

f = sqlalchemy.func.unnest(*args, **kw)
assert isinstance(f.type, sqlalchemy.String)
if packaging.version.parse(sqlalchemy.__version__) >= packaging.version.parse(
"1.4"
):
assert isinstance(
sqlalchemy.select([f]).subquery().c.unnest.type, sqlalchemy.String
)

0 comments on commit 1233182

Please sign in to comment.