diff --git a/docs/source/conflict_handling.rst b/docs/source/conflict_handling.rst index 9edc71b1..89d1a0cc 100644 --- a/docs/source/conflict_handling.rst +++ b/docs/source/conflict_handling.rst @@ -87,6 +87,41 @@ Specifying multiple columns is necessary in case of a constraint that spans mult ) +Specific constraint +******************* + +Alternatively, instead of specifying the columns the constraint you're targetting applies to, you can also specify the exact constraint to use: + +.. code-block:: python + + from django.db import models + from psqlextra.models import PostgresModel + + class MyModel(PostgresModel) + class Meta: + constraints = [ + models.UniqueConstraint( + name="myconstraint", + fields=["first_name", "last_name"] + ), + ] + + first_name = models.CharField(max_length=255) + last_name = models.CharField(max_length=255) + + constraint = next( + constraint + for constraint in MyModel._meta.constraints + if constraint.name == "myconstraint" + ), None) + + obj = ( + MyModel.objects + .on_conflict(constraint, ConflictAction.UPDATE) + .insert_and_get(first_name='Henk', last_name='Jansen') + ) + + HStore keys *********** Catching conflicts in columns with a ``UNIQUE`` constraint on a :class:`~psqlextra.fields.HStoreField` key is also supported: diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py index 12fff3fa..88a65e9a 100644 --- a/psqlextra/compiler.py +++ b/psqlextra/compiler.py @@ -243,11 +243,11 @@ def _rewrite_insert_on_conflict( # build the conflict target, the columns to watch # for conflicts - conflict_target = self._build_conflict_target() + on_conflict_clause = self._build_on_conflict_clause() index_predicate = self.query.index_predicate # type: ignore[attr-defined] update_condition = self.query.conflict_update_condition # type: ignore[attr-defined] - rewritten_sql = f"{sql} ON CONFLICT {conflict_target}" + rewritten_sql = f"{sql} {on_conflict_clause}" if index_predicate: expr_sql, expr_params = self._compile_expression(index_predicate) @@ -270,6 +270,21 @@ def _rewrite_insert_on_conflict( return (rewritten_sql, params) + def _build_on_conflict_clause(self): + if django.VERSION >= (2, 2): + from django.db.models.constraints import BaseConstraint + from django.db.models.indexes import Index + + if isinstance( + self.query.conflict_target, BaseConstraint + ) or isinstance(self.query.conflict_target, Index): + return "ON CONFLICT ON CONSTRAINT %s" % self.qn( + self.query.conflict_target.name + ) + + conflict_target = self._build_conflict_target() + return f"ON CONFLICT {conflict_target}" + def _build_conflict_target(self): """Builds the `conflict_target` for the ON CONFLICT clause.""" diff --git a/psqlextra/query.py b/psqlextra/query.py index 5c5e6f47..b3feec1d 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -20,7 +20,11 @@ from .sql import PostgresInsertQuery, PostgresQuery from .types import ConflictAction -ConflictTarget = List[Union[str, Tuple[str]]] +if TYPE_CHECKING: + from django.db.models.constraints import BaseConstraint + from django.db.models.indexes import Index + +ConflictTarget = Union[List[Union[str, Tuple[str]]], "BaseConstraint", "Index"] TModel = TypeVar("TModel", bound=models.Model, covariant=True) diff --git a/tests/test_on_conflict_update.py b/tests/test_on_conflict_update.py index 8425e3d3..b93e5781 100644 --- a/tests/test_on_conflict_update.py +++ b/tests/test_on_conflict_update.py @@ -1,3 +1,4 @@ +import django import pytest from django.db import models @@ -41,6 +42,35 @@ def test_on_conflict_update(): assert obj2.cookies == "choco" +@pytest.mark.skipif( + django.VERSION < (2, 2), + reason="Django < 2.2 doesn't implement constraints", +) +def test_on_conflict_update_by_unique_constraint(): + model = get_fake_model( + { + "title": models.CharField(max_length=255, null=True), + }, + meta_options={ + "constraints": [ + models.UniqueConstraint(name="test_uniq", fields=["title"]), + ], + }, + ) + + constraint = next( + ( + constraint + for constraint in model._meta.constraints + if constraint.name == "test_uniq" + ) + ) + + model.objects.on_conflict(constraint, ConflictAction.UPDATE).insert_and_get( + title="title" + ) + + def test_on_conflict_update_foreign_key_by_object(): """Tests whether simple upsert works correctly when the conflicting field is a foreign key specified as an object."""