Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/manager.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ obj = (
)
```

`bulk_insert` uses a single query to insert all specified rows at once.
`bulk_insert` uses a single query to insert all specified rows at once. It returns a `list` of `dict()` with each `dict()` being a merge of the `dict()` passed in along with any index returned from Postgres.

#### Limitations
In order to stick to the "everything in one query" principle, various, more advanced usages of `bulk_insert` are impossible. It is not possible to have different rows specify different amounts of columns. The following example does **not work**:
Expand Down
2 changes: 1 addition & 1 deletion psqlextra/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def execute_sql(self, return_id=False):
rows = []
for sql, params in self.as_sql(return_id):
cursor.execute(sql, params)
rows.append(cursor.fetchone())
rows.extend(cursor.fetchall())

# create a mapping between column names and column value
return [
Expand Down
17 changes: 13 additions & 4 deletions psqlextra/manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def on_conflict(self, fields: List[Union[str, Tuple[str]]], action, index_predic

return self

def bulk_insert(self, rows):
def bulk_insert(self, rows, return_model=False):
"""Creates multiple new records in the database.

This allows specifying custom conflict behavior using .on_conflict().
Expand All @@ -150,16 +150,25 @@ def bulk_insert(self, rows):
An array of dictionaries, where each dictionary
describes the fields to insert.

return_model (default: False):
If model instances should be returned rather than
just dicts.

Returns:
A list of either the dicts of the rows inserted, including the pk or
the models of the rows inserted with defaults for any fields not specified
"""

if self.conflict_target or self.conflict_action:
compiler = self._build_insert_compiler(rows)
compiler.execute_sql(return_id=True)
return
objs = compiler.execute_sql(return_id=True)
if return_model:
return [self.model(**dict(r, **k)) for r, k in zip(rows, objs)]
else:
return [dict(r, **k) for r, k in zip(rows, objs)]

# no special action required, use the standard Django bulk_create(..)
super().bulk_create([self.model(**fields) for fields in rows])
return super().bulk_create([self.model(**fields) for fields in rows])

def insert(self, **fields):
"""Creates a new record in the database.
Expand Down
34 changes: 34 additions & 0 deletions tests/test_on_conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,3 +413,37 @@ def test_on_conflict_bulk():

for index, obj in enumerate(list(model.objects.all())):
assert obj.title == rows[index]['title']

def test_bulk_return():
"""Tests if primary keys are properly returned from 'bulk_insert'
"""

model = get_fake_model({
'id': models.BigAutoField(primary_key=True),
'name': models.CharField(max_length=255, unique=True)
})

rows = [
dict(name='John Smith'),
dict(name='Jane Doe')
]

objs = (
model.objects
.on_conflict(['name'], ConflictAction.UPDATE)
.bulk_insert(rows)
)

for index, obj in enumerate(objs, 1):
assert obj['id'] == index

"""Add objects again, update should return the same ids
as we're just updating."""
objs = (
model.objects
.on_conflict(['name'], ConflictAction.UPDATE)
.bulk_insert(rows)
)

for index, obj in enumerate(objs, 1):
assert obj['id'] == index