Skip to content

Commit ea58192

Browse files
authored
Merge pull request #52 from lairdm/master
bulk_insert() wasn't returning any objects as documented.
2 parents 7efc21e + 3742fe9 commit ea58192

File tree

4 files changed

+49
-6
lines changed

4 files changed

+49
-6
lines changed

docs/manager.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ obj = (
209209
)
210210
```
211211

212-
`bulk_insert` uses a single query to insert all specified rows at once.
212+
`bulk_insert` uses a single query to insert all specified rows at once. It returns a `list` of `dict()` with each `dict()` being a merge of the `dict()` passed in along with any index returned from Postgres.
213213

214214
#### Limitations
215215
In order to stick to the "everything in one query" principle, various, more advanced usages of `bulk_insert` are impossible. It is not possible to have different rows specify different amounts of columns. The following example does **not work**:

psqlextra/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def execute_sql(self, return_id=False):
4848
rows = []
4949
for sql, params in self.as_sql(return_id):
5050
cursor.execute(sql, params)
51-
rows.append(cursor.fetchone())
51+
rows.extend(cursor.fetchall())
5252

5353
# create a mapping between column names and column value
5454
return [

psqlextra/manager/manager.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def on_conflict(self, fields: List[Union[str, Tuple[str]]], action, index_predic
139139

140140
return self
141141

142-
def bulk_insert(self, rows):
142+
def bulk_insert(self, rows, return_model=False):
143143
"""Creates multiple new records in the database.
144144
145145
This allows specifying custom conflict behavior using .on_conflict().
@@ -150,16 +150,25 @@ def bulk_insert(self, rows):
150150
An array of dictionaries, where each dictionary
151151
describes the fields to insert.
152152
153+
return_model (default: False):
154+
If model instances should be returned rather than
155+
just dicts.
156+
153157
Returns:
158+
A list of either the dicts of the rows inserted, including the pk or
159+
the models of the rows inserted with defaults for any fields not specified
154160
"""
155161

156162
if self.conflict_target or self.conflict_action:
157163
compiler = self._build_insert_compiler(rows)
158-
compiler.execute_sql(return_id=True)
159-
return
164+
objs = compiler.execute_sql(return_id=True)
165+
if return_model:
166+
return [self.model(**dict(r, **k)) for r, k in zip(rows, objs)]
167+
else:
168+
return [dict(r, **k) for r, k in zip(rows, objs)]
160169

161170
# no special action required, use the standard Django bulk_create(..)
162-
super().bulk_create([self.model(**fields) for fields in rows])
171+
return super().bulk_create([self.model(**fields) for fields in rows])
163172

164173
def insert(self, **fields):
165174
"""Creates a new record in the database.

tests/test_on_conflict.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,37 @@ def test_on_conflict_bulk():
413413

414414
for index, obj in enumerate(list(model.objects.all())):
415415
assert obj.title == rows[index]['title']
416+
417+
def test_bulk_return():
418+
"""Tests if primary keys are properly returned from 'bulk_insert'
419+
"""
420+
421+
model = get_fake_model({
422+
'id': models.BigAutoField(primary_key=True),
423+
'name': models.CharField(max_length=255, unique=True)
424+
})
425+
426+
rows = [
427+
dict(name='John Smith'),
428+
dict(name='Jane Doe')
429+
]
430+
431+
objs = (
432+
model.objects
433+
.on_conflict(['name'], ConflictAction.UPDATE)
434+
.bulk_insert(rows)
435+
)
436+
437+
for index, obj in enumerate(objs, 1):
438+
assert obj['id'] == index
439+
440+
"""Add objects again, update should return the same ids
441+
as we're just updating."""
442+
objs = (
443+
model.objects
444+
.on_conflict(['name'], ConflictAction.UPDATE)
445+
.bulk_insert(rows)
446+
)
447+
448+
for index, obj in enumerate(objs, 1):
449+
assert obj['id'] == index

0 commit comments

Comments
 (0)