Skip to content
Merged
4 changes: 3 additions & 1 deletion django_spanner/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
SQLInsertCompiler as BaseSQLInsertCompiler,
SQLUpdateCompiler as BaseSQLUpdateCompiler,
)
from django.db.utils import DatabaseError
from django.db.utils import DatabaseError, ensure_where_clause


class SQLCompiler(BaseSQLCompiler):
Expand Down Expand Up @@ -90,6 +90,8 @@ def get_combinator_sql(self, combinator, all):
params = []
for part in args_parts:
params.extend(part)

result = ensure_where_clause(result)
return result, params


Expand Down
15 changes: 15 additions & 0 deletions django_spanner/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import django
import sqlparse
from django.core.exceptions import ImproperlyConfigured
from django.utils.version import get_version_tuple

Expand All @@ -18,3 +19,17 @@ def check_django_compatability():
A=django.VERSION[0], B=django.VERSION[1], C=__version__
)
)


def ensure_where_clause(sql):
"""
Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements.
Add a dummy WHERE clause if necessary.
"""
if any(
isinstance(token, sqlparse.sql.Where)
for token in sqlparse.parse(sql)[0]
):
return sql

return sql + " WHERE 1=1"
107 changes: 106 additions & 1 deletion google/cloud/spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,105 @@

RE_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL)

SPANNER_RESERVED_KEYWORDS = {
"ALL",
"AND",
"ANY",
"ARRAY",
"AS",
"ASC",
"ASSERT_ROWS_MODIFIED",
"AT",
"BETWEEN",
"BY",
"CASE",
"CAST",
"COLLATE",
"CONTAINS",
"CREATE",
"CROSS",
"CUBE",
"CURRENT",
"DEFAULT",
"DEFINE",
"DESC",
"DISTINCT",
"DROP",
"ELSE",
"END",
"ENUM",
"ESCAPE",
"EXCEPT",
"EXCLUDE",
"EXISTS",
"EXTRACT",
"FALSE",
"FETCH",
"FOLLOWING",
"FOR",
"FROM",
"FULL",
"GROUP",
"GROUPING",
"GROUPS",
"HASH",
"HAVING",
"IF",
"IGNORE",
"IN",
"INNER",
"INTERSECT",
"INTERVAL",
"INTO",
"IS",
"JOIN",
"LATERAL",
"LEFT",
"LIKE",
"LIMIT",
"LOOKUP",
"MERGE",
"NATURAL",
"NEW",
"NO",
"NOT",
"NULL",
"NULLS",
"OF",
"ON",
"OR",
"ORDER",
"OUTER",
"OVER",
"PARTITION",
"PRECEDING",
"PROTO",
"RANGE",
"RECURSIVE",
"RESPECT",
"RIGHT",
"ROLLUP",
"ROWS",
"SELECT",
"SET",
"SOME",
"STRUCT",
"TABLESAMPLE",
"THEN",
"TO",
"TREAT",
"TRUE",
"UNBOUNDED",
"UNION",
"UNNEST",
"USING",
"WHEN",
"WHERE",
"WINDOW",
"WITH",
"WITHIN",
}


def classify_stmt(query):
"""Determine SQL query type.
Expand Down Expand Up @@ -517,13 +616,19 @@ def ensure_where_clause(sql):
"""
Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements.
Add a dummy WHERE clause if necessary.

:type sql: `str`
:param sql: SQL code to check.
"""
if any(
isinstance(token, sqlparse.sql.Where)
for token in sqlparse.parse(sql)[0]
):
return sql
return sql + " WHERE 1=1"

raise ProgrammingError(
"Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query"
)


def escape_name(name):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/spanner_dbapi/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_do_execute_update(self):
def run_helper(ret_value):
transaction.execute_update.return_value = ret_value
res = cursor._do_execute_update(
transaction=transaction, sql="sql", params=None,
transaction=transaction, sql="SELECT * WHERE true", params={},
)
return res

Expand Down
42 changes: 16 additions & 26 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,36 +407,26 @@ def test_get_param_types_none(self):
self.assertEqual(get_param_types(None), None)

def test_ensure_where_clause(self):
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause

cases = [
(
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
),
(
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1",
),
(
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"),
]
cases = (
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
)
err_cases = (
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"DELETE * FROM TABLE",
)
for sql in cases:
with self.subTest(sql=sql):
ensure_where_clause(sql)

for sql, want in cases:
for sql in err_cases:
with self.subTest(sql=sql):
got = ensure_where_clause(sql)
self.assertEqual(got, want)
with self.assertRaises(ProgrammingError):
ensure_where_clause(sql)

def test_escape_name(self):
from google.cloud.spanner_dbapi.parse_utils import escape_name
Expand Down