Commit 318fbc4

Author: Matthew Laird
The documentation for bulk_insert() has:

    obj = (
        MyModel.objects
        .on_conflict(['name'], ConflictAction.UPDATE)
        .bulk_insert([ ...

However, bulk_insert() wasn't actually returning anything. It is more useful to use the 'returning' clause already being added to the Postgres query to send back any auto-increment indexes the table might have. This makes the functionality better mirror bulk_create().

In addition, when using on_conflict(), this lets the user see the indexes of any entries that were updated.

To make this work, a bug in the compiler had to be fixed: when retrieving the rows of the 'returning' clause, it was only grabbing the first entry from the cursor, not all the returned records.

Tests and documentation have been added in this patch as well.
1 parent c6686f1 commit 318fbc4

4 files changed, 49 additions and 6 deletions
docs/manager.md

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ obj = (
 )
 ```
 
-`bulk_insert` uses a single query to insert all specified rows at once.
+`bulk_insert` uses a single query to insert all specified rows at once. It returns a `list` of `dict()` with each `dict()` being a merge of the `dict()` passed in along with any index returned from Postgres.
 
 #### Limitations
 In order to stick to the "everything in one query" principle, various, more advanced usages of `bulk_insert` are impossible. It is not possible to have different rows specify different amounts of columns. The following example does **not work**:

psqlextra/compiler.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def execute_sql(self, return_id=False):
             rows = []
             for sql, params in self.as_sql(return_id):
                 cursor.execute(sql, params)
-                rows.append(cursor.fetchone())
+                rows.extend(cursor.fetchall())
 
         # create a mapping between column names and column value
         return [
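
The difference between the removed and added lines can be seen in a minimal sketch. It uses Python's built-in `sqlite3` module as a stand-in for Postgres, since only the DB-API cursor behaviour matters here; the `person` table and the sample names are assumptions made for illustration, and a plain `SELECT` plays the role of a statement that produces several rows.

```python
import sqlite3

# SQLite stands in for Postgres; the table and data are illustrative only.
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
cursor.execute("CREATE TABLE person (id INTEGER PRIMARY KEY, name TEXT)")
cursor.executemany(
    "INSERT INTO person (name) VALUES (?)",
    [("John Smith",), ("Jane Doe",), ("Ada Lovelace",)],
)

query = "SELECT id FROM person ORDER BY id"

# Old behaviour: only the first row of the result is kept.
rows_before = []
cursor.execute(query)
rows_before.append(cursor.fetchone())

# New behaviour: every row the statement produced is kept.
rows_after = []
cursor.execute(query)
rows_after.extend(cursor.fetchall())

print(rows_before)  # [(1,)]
print(rows_after)   # [(1,), (2,), (3,)]
conn.close()
```

With a multi-row insert, the first variant would silently drop every returned row except the first, which is why the fix is needed before `bulk_insert` can return one entry per inserted row.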

psqlextra/manager/manager.py

Lines changed: 13 additions & 4 deletions
@@ -139,7 +139,7 @@ def on_conflict(self, fields: List[Union[str, Tuple[str]]], action, index_predic
 
         return self
 
-    def bulk_insert(self, rows):
+    def bulk_insert(self, rows, return_model=False):
         """Creates multiple new records in the database.
 
         This allows specifying custom conflict behavior using .on_conflict().
@@ -150,16 +150,25 @@ def bulk_insert(self, rows):
                 An array of dictionaries, where each dictionary
                 describes the fields to insert.
 
+            return_model (default: False):
+                If model instances should be returned rather than
+                just dicts.
+
         Returns:
+            A list of either the dicts of the rows inserted, including the pk or
+            the models of the rows inserted with defaults for any fields not specified
         """
 
         if self.conflict_target or self.conflict_action:
             compiler = self._build_insert_compiler(rows)
-            compiler.execute_sql(return_id=True)
-            return
+            objs = compiler.execute_sql(return_id=True)
+            if return_model:
+                return [ self.model(**dict(r, **k)) for r,k in zip(rows,objs) ]
+            else:
+                return [ dict(r, **k) for r,k in zip(rows,objs) ]
 
         # no special action required, use the standard Django bulk_create(..)
-        super().bulk_create([self.model(**fields) for fields in rows])
+        return super().bulk_create([self.model(**fields) for fields in rows])
 
     def insert(self, **fields):
         """Creates a new record in the database.

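The merge step in `bulk_insert` can be tried in isolation. The following is a rough sketch, not the library's code: `Person` is a hypothetical dataclass standing in for a Django model, `merge_returned` is a hypothetical helper, and the `returned` list assumes the compiler hands back one `{'id': ...}` dict per input row.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Person:
    """Hypothetical stand-in for a Django model; not part of psqlextra."""
    name: str
    id: Optional[int] = None


def merge_returned(rows, returned, model=None):
    """Pair each input row with the dict returned for it and merge the two."""
    merged = [dict(row, **extra) for row, extra in zip(rows, returned)]
    if model is not None:
        # Mirrors the return_model=True branch: build instances from the merged fields.
        return [model(**fields) for fields in merged]
    return merged


rows = [dict(name="John Smith"), dict(name="Jane Doe")]
returned = [{"id": 1}, {"id": 2}]  # assumed shape of the RETURNING result

as_dicts = merge_returned(rows, returned)
as_models = merge_returned(rows, returned, model=Person)
print(as_dicts)
print(as_models)
```

Because `zip` pairs entries by position, this approach relies on the returned rows arriving in the same order as the input rows.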
tests/test_on_conflict.py

Lines changed: 34 additions & 0 deletions
@@ -413,3 +413,37 @@ def test_on_conflict_bulk():
 
     for index, obj in enumerate(list(model.objects.all())):
         assert obj.title == rows[index]['title']
+
+def test_bulk_return():
+    """Tests if primary keys are properly returned from 'bulk_insert'
+    """
+
+    model = get_fake_model({
+        'id': models.BigAutoField(primary_key=True),
+        'name': models.CharField(max_length=255, unique=True)
+    })
+
+    rows = [
+        dict(name='John Smith'),
+        dict(name='Jane Doe')
+    ]
+
+    objs = (
+        model.objects
+        .on_conflict(['name'], ConflictAction.UPDATE)
+        .bulk_insert(rows)
+    )
+
+    for index, obj in enumerate(objs, 1):
+        assert obj['id'] == index
+
+    """Add objects again, update should return the same ids
+    as we're just updating."""
+    objs = (
+        model.objects
+        .on_conflict(['name'], ConflictAction.UPDATE)
+        .bulk_insert(rows)
+    )
+
+    for index, obj in enumerate(objs, 1):
+        assert obj['id'] == index
