Skip to content

Commit b93345c

Browse files
committed
Support for conditions in upserts
1 parent ae5b2be commit b93345c

File tree

5 files changed

+91
-27
lines changed

5 files changed

+91
-27
lines changed

psqlextra/compiler.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections.abc import Iterable
22

33
from django.core.exceptions import SuspiciousOperation
4-
from django.db.models import Model
4+
from django.db.models import Expression, Model
55
from django.db.models.fields.related import RelatedField
66
from django.db.models.sql.compiler import SQLInsertCompiler, SQLUpdateCompiler
77
from django.db.utils import ProgrammingError
@@ -150,30 +150,31 @@ def _rewrite_insert_on_conflict(
150150
# for conflicts
151151
conflict_target = self._build_conflict_target()
152152
index_predicate = self.query.index_predicate
153+
update_condition = self.query.conflict_update_condition
153154

154-
sql_template = (
155-
"{insert} ON CONFLICT {conflict_target} DO {conflict_action}"
156-
)
155+
rewritten_sql = f"{sql} ON CONFLICT {conflict_target}"
157156

158157
if index_predicate:
159-
sql_template = "{insert} ON CONFLICT {conflict_target} WHERE {index_predicate} DO {conflict_action}"
158+
if isinstance(index_predicate, Expression):
159+
expr_sql, expr_params = self.compile(index_predicate)
160+
rewritten_sql += f" WHERE {expr_sql}"
161+
params += tuple(expr_params)
162+
else:
163+
rewritten_sql += f" WHERE {index_predicate}"
164+
165+
rewritten_sql += f" DO {conflict_action}"
160166

161167
if conflict_action == "UPDATE":
162-
sql_template += " SET {update_columns}"
163-
164-
sql_template += " RETURNING {returning}"
165-
166-
return (
167-
sql_template.format(
168-
insert=sql,
169-
conflict_target=conflict_target,
170-
conflict_action=conflict_action,
171-
update_columns=update_columns,
172-
returning=returning,
173-
index_predicate=index_predicate,
174-
),
175-
params,
176-
)
168+
rewritten_sql += f" SET {update_columns}"
169+
170+
if update_condition:
171+
expr_sql, expr_params = self.compile(update_condition)
172+
rewritten_sql += f" WHERE {expr_sql}"
173+
params += tuple(expr_params)
174+
175+
rewritten_sql += f" RETURNING {returning}"
176+
177+
return (rewritten_sql, params)
177178

178179
def _build_conflict_target(self):
179180
"""Builds the `conflict_target` for the ON CONFLICT clause."""

psqlextra/query.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from django.core.exceptions import SuspiciousOperation
66
from django.db import models, router
7+
from django.db.models import Expression
78
from django.db.models.fields import NOT_PROVIDED
89

910
from .sql import PostgresInsertQuery, PostgresQuery
@@ -24,6 +25,7 @@ def __init__(self, model=None, query=None, using=None, hints=None):
2425

2526
self.conflict_target = None
2627
self.conflict_action = None
28+
self.conflict_update_condition = None
2729
self.index_predicate = None
2830

2931
def annotate(self, **annotations):
@@ -80,7 +82,8 @@ def on_conflict(
8082
self,
8183
fields: ConflictTarget,
8284
action: ConflictAction,
83-
index_predicate: Optional[str] = None,
85+
index_predicate: Optional[Union[Expression, str]] = None,
86+
update_condition: Optional[Expression] = None,
8487
):
8588
"""Sets the action to take when conflicts arise when attempting to
8689
insert/create a new row.
@@ -95,10 +98,14 @@ def on_conflict(
9598
index_predicate:
9699
The index predicate to satisfy an arbiter partial index (i.e. what partial index to use for checking
97100
conflicts)
101+
102+
update_condition:
103+
Only update if this SQL expression evaluates to true.
98104
"""
99105

100106
self.conflict_target = fields
101107
self.conflict_action = action
108+
self.conflict_update_condition = update_condition
102109
self.index_predicate = index_predicate
103110

104111
return self
@@ -250,8 +257,9 @@ def upsert(
250257
self,
251258
conflict_target: ConflictTarget,
252259
fields: dict,
253-
index_predicate: Optional[str] = None,
260+
index_predicate: Optional[Union[Expression, str]] = None,
254261
using: Optional[str] = None,
262+
update_condition: Optional[Expression] = None,
255263
) -> int:
256264
"""Creates a new record or updates the existing one with the specified
257265
data.
@@ -271,21 +279,28 @@ def upsert(
271279
The name of the database connection to
272280
use for this query.
273281
282+
update_condition:
283+
Only update if this SQL expression evaluates to true.
284+
274285
Returns:
275286
The primary key of the row that was created/updated.
276287
"""
277288

278289
self.on_conflict(
279-
conflict_target, ConflictAction.UPDATE, index_predicate
290+
conflict_target,
291+
ConflictAction.UPDATE,
292+
index_predicate=index_predicate,
293+
update_condition=update_condition,
280294
)
281295
return self.insert(**fields, using=using)
282296

283297
def upsert_and_get(
284298
self,
285299
conflict_target: ConflictTarget,
286300
fields: dict,
287-
index_predicate: Optional[str] = None,
301+
index_predicate: Optional[Union[Expression, str]] = None,
288302
using: Optional[str] = None,
303+
update_condition: Optional[Expression] = None,
289304
):
290305
"""Creates a new record or updates the existing one with the specified
291306
data and then gets the row.
@@ -305,23 +320,30 @@ def upsert_and_get(
305320
The name of the database connection to
306321
use for this query.
307322
323+
update_condition:
324+
Only update if this SQL expression evaluates to true.
325+
308326
Returns:
309327
The model instance representing the row
310328
that was created/updated.
311329
"""
312330

313331
self.on_conflict(
314-
conflict_target, ConflictAction.UPDATE, index_predicate
332+
conflict_target,
333+
ConflictAction.UPDATE,
334+
index_predicate=index_predicate,
335+
update_condition=update_condition,
315336
)
316337
return self.insert_and_get(**fields, using=using)
317338

318339
def bulk_upsert(
319340
self,
320341
conflict_target: ConflictTarget,
321342
rows: Iterable[Dict],
322-
index_predicate: str = None,
343+
index_predicate: Optional[Union[Expression, str]] = None,
323344
return_model: bool = False,
324345
using: Optional[str] = None,
346+
update_condition: Optional[Expression] = None,
325347
):
326348
"""Creates a set of new records or updates the existing ones with the
327349
specified data.
@@ -345,6 +367,9 @@ def bulk_upsert(
345367
The name of the database connection to use
346368
for this query.
347369
370+
update_condition:
371+
Only update if this SQL expression evaluates to true.
372+
348373
Returns:
349374
A list of either the dicts of the rows upserted, including the pk or
350375
the models of the rows upserted
@@ -357,7 +382,10 @@ def is_empty(r):
357382
return []
358383

359384
self.on_conflict(
360-
conflict_target, ConflictAction.UPDATE, index_predicate
385+
conflict_target,
386+
ConflictAction.UPDATE,
387+
index_predicate=index_predicate,
388+
update_condition=update_condition,
361389
)
362390
return self.bulk_insert(rows, return_model, using=using)
363391

@@ -425,6 +453,7 @@ def _build_insert_compiler(
425453
query = PostgresInsertQuery(self.model)
426454
query.conflict_action = self.conflict_action
427455
query.conflict_target = self.conflict_target
456+
query.conflict_update_condition = self.conflict_update_condition
428457
query.index_predicate = self.index_predicate
429458
query.values(objs, insert_fields, update_fields)
430459

psqlextra/sql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ def __init__(self, *args, **kwargs):
145145

146146
self.conflict_target = []
147147
self.conflict_action = ConflictAction.UPDATE
148+
self.conflict_update_condition = None
149+
self.index_predicate = None
148150

149151
self.update_fields = []
150152

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22
DJANGO_SETTINGS_MODULE=settings
33
testpaths=tests
44
addopts=-m "not benchmark"
5+
filterwarnings =
6+
ignore::UserWarning

tests/test_upsert.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from django.db import models
2+
from django.db.models.expressions import CombinedExpression, Value
23

34
from psqlextra.fields import HStoreField
45

@@ -76,6 +77,35 @@ def test_upsert_explicit_pk():
7677
assert obj2.cookies == "second-boo"
7778

7879

80+
def test_upsert_with_update_condition():
81+
"""Tests that a custom expression can be passed as an update condition."""
82+
83+
model = get_fake_model(
84+
{
85+
"name": models.TextField(unique=True),
86+
"priority": models.IntegerField(),
87+
"active": models.BooleanField(),
88+
}
89+
)
90+
91+
obj1 = model.objects.create(name="joe", priority=1, active=False)
92+
93+
model.objects.upsert(
94+
conflict_target=["name"],
95+
update_condition=CombinedExpression(
96+
model._meta.get_field("active").get_col(model._meta.db_table),
97+
"=",
98+
Value(True),
99+
),
100+
fields=dict(name="joe", priority=2, active=True),
101+
)
102+
103+
obj1.refresh_from_db()
104+
105+
assert obj1.priority == 1
106+
assert not obj1.active
107+
108+
79109
def test_upsert_bulk():
80110
"""Tests whether bulk_upsert works properly."""
81111

0 commit comments

Comments
 (0)